Skip to content

Commit db04463

Browse files
bweedopivikash
authored andcommitted
feat(amazonq): user can generate unit tests for selected code #5577
- Add command to Amazon Q to generate unit tests for selected code. The feature will be limited to internal Amazon users, initially. - Add "Generate Tests" command.
1 parent 47b436f commit db04463

File tree

15 files changed

+175
-44
lines changed

15 files changed

+175
-44
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "Bug Fix",
3+
"description": "do not trigger uninstall event when the extension auto-updates"
4+
}

packages/amazonq/package.json

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,14 @@
349349
"command": "aws.amazonq.optimizeCode",
350350
"group": "cw_chat@4"
351351
},
352+
{
353+
"command": "aws.amazonq.generateUnitTests",
354+
"group": "cw_chat@5",
355+
"when": "aws.codewhisperer.connected && aws.isInternalUser"
356+
},
352357
{
353358
"command": "aws.amazonq.sendToPrompt",
354-
"group": "cw_chat@5"
359+
"group": "cw_chat@6"
355360
}
356361
],
357362
"editor/context": [
@@ -423,6 +428,12 @@
423428
"category": "%AWS.amazonq.title%",
424429
"enablement": "aws.codewhisperer.connected"
425430
},
431+
{
432+
"command": "aws.amazonq.generateUnitTests",
433+
"title": "%AWS.command.amazonq.generateUnitTests%",
434+
"category": "%AWS.amazonq.title%",
435+
"enablement": "aws.codewhisperer.connected && aws.isInternalUser"
436+
},
426437
{
427438
"command": "aws.amazonq.reconnect",
428439
"title": "%AWS.command.codewhisperer.reconnect%",
@@ -599,6 +610,13 @@
599610
"mac": "cmd+alt+q",
600611
"linux": "meta+alt+q"
601612
},
613+
{
614+
"command": "aws.amazonq.generateUnitTests",
615+
"key": "win+alt+t",
616+
"mac": "cmd+alt+t",
617+
"linux": "meta+alt+t",
618+
"when": "aws.codewhisperer.connected && aws.isInternalUser"
619+
},
602620
{
603621
"command": "aws.amazonq.invokeInlineCompletion",
604622
"key": "alt+c",

packages/core/package.nls.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
"AWS.command.amazonq.fixCode": "Fix",
112112
"AWS.command.amazonq.optimizeCode": "Optimize",
113113
"AWS.command.amazonq.sendToPrompt": "Send to prompt",
114+
"AWS.command.amazonq.generateUnitTests": "Generate Tests (Beta)",
114115
"AWS.command.amazonq.security.scan": "Run Project Scan",
115116
"AWS.command.deploySamApplication": "Deploy SAM Application",
116117
"AWS.command.aboutToolkit": "About",

packages/core/src/auth/auth.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,15 @@ import {
5959
AwsConnection,
6060
scopesCodeWhispererCore,
6161
ProfileNotFoundError,
62+
isSsoConnection,
6263
} from './connection'
6364
import { isSageMaker, isCloud9, isAmazonQ } from '../shared/extensionUtilities'
6465
import { telemetry } from '../shared/telemetry/telemetry'
6566
import { randomUUID } from '../shared/crypto'
6667
import { asStringifiedStack } from '../shared/telemetry/spans'
6768
import { withTelemetryContext } from '../shared/telemetry/util'
6869
import { DiskCacheError } from '../shared/utilities/cacheUtils'
70+
import { setContext } from '../shared/vscode/setContext'
6971

7072
interface AuthService {
7173
/**
@@ -166,6 +168,30 @@ export class Auth implements AuthService, ConnectionManager {
166168
return this.#ssoCacheWatcher
167169
}
168170

171+
public get startUrl(): string | undefined {
172+
return isSsoConnection(this.activeConnection)
173+
? this.normalizeStartUrl(this.activeConnection.startUrl)
174+
: undefined
175+
}
176+
177+
public isConnected(): boolean {
178+
return this.activeConnection !== undefined
179+
}
180+
181+
/**
182+
* Normalizes the provided URL
183+
*
184+
* Any trailing '/' and `#` is removed from the URL
185+
* e.g. https://view.awsapps.com/start/# will become https://view.awsapps.com/start
186+
*/
187+
public normalizeStartUrl(startUrl: string | undefined) {
188+
return !startUrl ? undefined : startUrl.replace(/[\/#]+$/g, '')
189+
}
190+
191+
public isInternalAmazonUser(): boolean {
192+
return this.isConnected() && this.startUrl === 'https://amzn.awsapps.com/start'
193+
}
194+
169195
/**
170196
* Map startUrl -> declared connections
171197
*/
@@ -223,6 +249,8 @@ export class Auth implements AuthService, ConnectionManager {
223249
this.#onDidChangeActiveConnection.fire(conn)
224250
await this.store.setCurrentProfileId(id)
225251

252+
await setContext('aws.isInternalUser', this.isInternalAmazonUser())
253+
226254
return conn
227255
}
228256

@@ -373,6 +401,7 @@ export class Auth implements AuthService, ConnectionManager {
373401
}
374402
}
375403
this.#onDidDeleteConnection.fire({ connId, storedProfile: profile })
404+
await setContext('aws.isInternalUser', false)
376405
}
377406

378407
@withTelemetryContext({ name: 'clearStaleLinkedIamConnections', class: authClassName })
@@ -405,6 +434,7 @@ export class Auth implements AuthService, ConnectionManager {
405434
await provider.invalidate('devModeManualExpiration')
406435
// updates the state of the connection
407436
await this.refreshConnectionState(conn)
437+
await setContext('aws.isInternalUser', false)
408438
}
409439

410440
public async getConnection(connection: Pick<Connection, 'id'>): Promise<Connection | undefined> {

packages/core/src/codewhispererChat/commands/registerCommands.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ export function registerCommands(controllerPublishers: ChatControllerMessagePubl
9494
})
9595
})
9696
})
97+
Commands.register('aws.amazonq.generateUnitTests', async (data) => {
98+
return focusAmazonQPanel.execute(placeholder, 'amazonq.generateUnitTests').then(() => {
99+
controllerPublishers.processContextMenuCommand.publish({
100+
type: 'aws.amazonq.generateUnitTests',
101+
triggerType: getCommandTriggerType(data),
102+
})
103+
})
104+
})
97105
}
98106

99107
export type EditorContextBaseCommandType =
@@ -102,6 +110,7 @@ export type EditorContextBaseCommandType =
102110
| 'aws.amazonq.fixCode'
103111
| 'aws.amazonq.optimizeCode'
104112
| 'aws.amazonq.sendToPrompt'
113+
| 'aws.amazonq.generateUnitTests'
105114

106115
export type CodeScanIssueCommandType = 'aws.amazonq.explainIssue'
107116

packages/core/src/codewhispererChat/controllers/chat/messenger/messenger.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ export class Messenger {
340340
['aws.amazonq.fixCode', 'Fix'],
341341
['aws.amazonq.optimizeCode', 'Optimize'],
342342
['aws.amazonq.sendToPrompt', 'Send to prompt'],
343+
['aws.amazonq.generateUnitTests', 'Generate unit tests for'],
343344
])
344345

345346
public sendStaticTextResponse(type: StaticTextResponseType, triggerID: string, tabID: string) {

packages/core/src/codewhispererChat/controllers/chat/prompts/promptsGenerator.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export class PromptsGenerator {
1414
['aws.amazonq.fixCode', 'Fix'],
1515
['aws.amazonq.optimizeCode', 'Optimize'],
1616
['aws.amazonq.sendToPrompt', 'Send to prompt'],
17+
['aws.amazonq.generateUnitTests', 'Generate unit tests for'],
1718
])
1819

1920
public generateForContextMenuCommand(command: EditorContextCommand): string {

packages/core/src/codewhispererChat/controllers/chat/telemetryHelper.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ export class CWCTelemetryHelper {
8686
return 'explainLineByLine'
8787
case UserIntent.SHOW_EXAMPLES:
8888
return 'showExample'
89+
case UserIntent.GENERATE_UNIT_TESTS:
90+
return 'generateUnitTests'
8991
default:
9092
return undefined
9193
}

packages/core/src/codewhispererChat/controllers/chat/userIntent/userIntentRecognizer.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import { UserIntent } from '@amzn/codewhisperer-streaming'
77
import { EditorContextCommand } from '../../../commands/registerCommands'
88
import { PromptMessage } from '../model'
9+
import { Auth } from '../../../../auth'
910

1011
export class UserIntentRecognizer {
1112
public getFromContextMenuCommand(command: EditorContextCommand): UserIntent | undefined {
@@ -18,6 +19,8 @@ export class UserIntentRecognizer {
1819
return UserIntent.APPLY_COMMON_BEST_PRACTICES
1920
case 'aws.amazonq.optimizeCode':
2021
return UserIntent.IMPROVE_CODE
22+
case 'aws.amazonq.generateUnitTests':
23+
return UserIntent.GENERATE_UNIT_TESTS
2124
default:
2225
return undefined
2326
}
@@ -36,6 +39,8 @@ export class UserIntentRecognizer {
3639
return UserIntent.APPLY_COMMON_BEST_PRACTICES
3740
} else if (prompt.message.startsWith('Optimize')) {
3841
return UserIntent.IMPROVE_CODE
42+
} else if (prompt.message.startsWith('Generate unit tests') && Auth.instance.isInternalAmazonUser()) {
43+
return UserIntent.GENERATE_UNIT_TESTS
3944
}
4045
return undefined
4146
}

packages/core/src/shared/handleUninstall.ts

Lines changed: 86 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,59 +7,113 @@
77

88
import * as vscode from 'vscode'
99
import { existsSync } from 'fs'
10+
import * as semver from 'semver'
1011
import { join } from 'path'
1112
import { getLogger } from './logger/logger'
1213
import { telemetry } from './telemetry'
1314
import { VSCODE_EXTENSION_ID } from './extensions'
1415
import { extensionVersion } from './vscode/env'
1516

1617
/**
17-
* Checks if the extension has been uninstalled by reading the .obsolete file
18-
* and comparing the number of obsolete extensions with the installed extensions.
18+
* Checks if an extension has been uninstalled and performs a callback if so.
19+
* This function differentiates between an uninstall and an auto-update.
1920
*
20-
* @param {string} extensionName - The name of the extension.
21-
* @param {string} extensionsDirPath - The path to the extensions directory.
22-
* @param {string} obsoleteFilePath - The path to the .obsolete file.
23-
* @param {function} callback - Action performed when extension is uninstalled.
24-
* @returns {void}
21+
* @param extensionId - The ID of the extension to check (e.g., VSCODE_EXTENSION_ID.awstoolkit)
22+
* @param extensionsPath - The file system path to the VS Code extensions directory
23+
* @param obsoletePath - The file system path to the .obsolete file
24+
* @param onUninstallCallback - A callback function to execute if the extension is uninstalled
2525
*/
2626
async function checkExtensionUninstall(
27-
extensionName: typeof VSCODE_EXTENSION_ID.awstoolkit | typeof VSCODE_EXTENSION_ID.amazonq,
28-
extensionsDirPath: string,
29-
obsoleteFilePath: string,
30-
callback: () => Promise<void>
27+
extensionId: typeof VSCODE_EXTENSION_ID.awstoolkit | typeof VSCODE_EXTENSION_ID.amazonq,
28+
extensionsPath: string,
29+
obsoletePath: string,
30+
onUninstallCallback: () => Promise<void>
3131
): Promise<void> {
32-
/**
33-
* Users can have multiple profiles with different versions of the extensions.
34-
*
35-
* This makes sure the callback is triggered only when an explicit extension with specific version is uninstalled.
36-
*/
37-
const extension = `${extensionName}-${extensionVersion}`
32+
const extensionFullName = `${extensionId}-${extensionVersion}`
33+
3834
try {
39-
const [obsoleteFileContent, extensionsDirContent] = await Promise.all([
40-
vscode.workspace.fs.readFile(vscode.Uri.file(obsoleteFilePath)),
41-
vscode.workspace.fs.readDirectory(vscode.Uri.file(extensionsDirPath)),
35+
const [obsoleteFileContent, extensionDirEntries] = await Promise.all([
36+
vscode.workspace.fs.readFile(vscode.Uri.file(obsoletePath)),
37+
vscode.workspace.fs.readDirectory(vscode.Uri.file(extensionsPath)),
4238
])
4339

44-
const installedExtensionsCount = extensionsDirContent
45-
.map(([name]) => name)
46-
.filter((name) => name.includes(extension)).length
47-
4840
const obsoleteExtensions = JSON.parse(obsoleteFileContent.toString())
49-
const obsoleteExtensionsCount = Object.keys(obsoleteExtensions).filter((id) => id.includes(extension)).length
50-
51-
if (installedExtensionsCount === obsoleteExtensionsCount) {
52-
await callback()
53-
telemetry.aws_extensionUninstalled.run((span) => {
54-
span.record({})
55-
})
56-
getLogger().info(`UninstallExtension: ${extension} uninstalled successfully`)
41+
const currentExtension = vscode.extensions.getExtension(extensionId)
42+
43+
if (!currentExtension) {
44+
// Check if the extension was previously installed and is now in the obsolete list
45+
const wasObsolete = Object.keys(obsoleteExtensions).some((id) => id.startsWith(extensionId))
46+
if (wasObsolete) {
47+
await handleUninstall(extensionFullName, onUninstallCallback)
48+
}
49+
} else {
50+
// Check if there's a newer version in the extensions directory
51+
const newerVersionExists = checkForNewerVersion(
52+
extensionDirEntries,
53+
extensionId,
54+
currentExtension.packageJSON.version
55+
)
56+
57+
if (!newerVersionExists) {
58+
// No newer version exists, so this is likely an uninstall
59+
await handleUninstall(extensionFullName, onUninstallCallback)
60+
} else {
61+
getLogger().info(`UpdateExtension: ${extensionFullName} is being updated`)
62+
}
5763
}
5864
} catch (error) {
5965
getLogger().error(`UninstallExtension: Failed to check .obsolete: ${error}`)
6066
}
6167
}
6268

69+
/**
70+
* Checks if a newer version of the extension exists in the extensions directory.
71+
* The isExtensionInstalled fn is used to determine if the extension is installed using the vscode API
72+
* whereas this function checks for the newer version in the extension directory for scenarios where
73+
* the old extension is not un-installed and the new extension in downloaded but not installed.
74+
*
75+
* @param onUninstallCallback - A callback function to execute if the extension is uninstalled
76+
* @param isExtensionInstalled - A function to check if the extension is installed
77+
* @param dirEntries - The entries in the extensions directory
78+
* @param extensionId - The ID of the extension to check
79+
* @param currentVersion - The current version of the extension
80+
* @returns True if a newer version exists, false otherwise
81+
*/
82+
83+
function checkForNewerVersion(
84+
dirEntries: [string, vscode.FileType][],
85+
extensionId: string,
86+
currentVersion: string
87+
): boolean {
88+
const versionRegex = new RegExp(`^${extensionId}-(.+)$`)
89+
90+
return dirEntries
91+
.map(([name]) => name)
92+
.filter((name) => name.startsWith(extensionId))
93+
.some((name) => {
94+
const match = name.match(versionRegex)
95+
if (match && match[1]) {
96+
const version = semver.valid(semver.coerce(match[1]))
97+
return version !== null && semver.gt(version, currentVersion)
98+
}
99+
return false
100+
})
101+
}
102+
103+
/**
104+
* Handles the uninstall process by calling the callback and logging the event.
105+
*
106+
* @param extensionFullName - The full name of the extension including version
107+
* @param callback - The callback function to execute on uninstall
108+
*/
109+
async function handleUninstall(extensionFullName: string, callback: () => Promise<void>): Promise<void> {
110+
await callback()
111+
telemetry.aws_extensionUninstalled.run((span) => {
112+
span.record({})
113+
})
114+
getLogger().info(`UninstallExtension: ${extensionFullName} uninstalled successfully`)
115+
}
116+
63117
/**
64118
* Sets up a file system watcher to monitor the .obsolete file for changes and handle
65119
* extension un-installation if the extension is marked as obsolete.

0 commit comments

Comments
 (0)