Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,8 @@
"umd-compat-loader": "^2.1.2",
"vue-loader": "^17.2.2",
"vue-style-loader": "^4.1.3",
"webfont": "^11.2.26"
"webfont": "^11.2.26",
"shlex": "^2.1.2"
},
"dependencies": {
"@amzn/amazon-q-developer-streaming-client": "file:../../src.gen/@amzn/amazon-q-developer-streaming-client",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,9 @@ export class ChatController {
try {
await ToolUtils.validate(tool)

const chatStream = new ChatStream(this.messenger, tabID, triggerID, toolUse)
const chatStream = new ChatStream(this.messenger, tabID, triggerID, toolUse, {
requiresAcceptance: false,
})
const output = await ToolUtils.invoke(tool, chatStream)

toolResults.push({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import { ToolType, ToolUtils } from '../../../tools/toolUtils'
import { ChatStream } from '../../../tools/chatStream'
import path from 'path'
import { getWorkspaceForFile } from '../../../../shared/utilities/workspaceUtils'
import { CommandValidation } from '../../../tools/executeBash'

export type StaticTextResponseType = 'quick-action-help' | 'onboarding-help' | 'transform' | 'help'

Expand Down Expand Up @@ -220,11 +221,12 @@ export class Messenger {
if (tool.type === ToolType.FsWrite) {
session.setShowDiffOnFileWrite(true)
}
const requiresAcceptance = ToolUtils.requiresAcceptance(tool)
const chatStream = new ChatStream(this, tabID, triggerID, toolUse, requiresAcceptance)
const validation = ToolUtils.requiresAcceptance(tool)

const chatStream = new ChatStream(this, tabID, triggerID, toolUse, validation)
ToolUtils.queueDescription(tool, chatStream)

if (!requiresAcceptance) {
if (!validation.requiresAcceptance) {
// Need separate id for read tool and safe bash command execution as 'confirm-tool-use' id is required to change button status from `Confirm` to `Confirmed` state in cwChatConnector.ts which will impact generic tool execution.
this.dispatcher.sendCustomFormActionMessage(
new CustomFormActionMessage(tabID, {
Expand Down Expand Up @@ -432,17 +434,21 @@ export class Messenger {
tabID: string,
triggerID: string,
toolUse: ToolUse | undefined,
requiresAcceptance = false
validation: CommandValidation
) {
const buttons: ChatItemButton[] = []
let fileList: ChatItemContent['fileList'] = undefined
if (requiresAcceptance && toolUse?.name === ToolType.ExecuteBash) {
if (validation.requiresAcceptance && toolUse?.name === ToolType.ExecuteBash) {
buttons.push({
id: 'confirm-tool-use',
text: 'Confirm',
position: 'outside',
status: 'info',
})

if (validation.warning) {
message = validation.warning + message
}
} else if (toolUse?.name === ToolType.FsWrite) {
// FileList
const absoluteFilePath = (toolUse?.input as any).path
Expand Down Expand Up @@ -471,7 +477,7 @@ export class Messenger {
this.dispatcher.sendChatMessage(
new ChatMessage(
{
message,
message: message,
messageType: 'answer-part',
followUps: undefined,
followUpsHeader: undefined,
Expand Down
7 changes: 4 additions & 3 deletions packages/core/src/codewhispererChat/tools/chatStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { Writable } from 'stream'
import { getLogger } from '../../shared/logger/logger'
import { Messenger } from '../controllers/chat/messenger/messenger'
import { ToolUse } from '@amzn/codewhisperer-streaming'
import { CommandValidation } from './executeBash'

/**
* A writable stream that feeds each chunk/line to the chat UI.
Expand All @@ -20,7 +21,7 @@ export class ChatStream extends Writable {
private readonly tabID: string,
private readonly triggerID: string,
private readonly toolUse: ToolUse | undefined,
private readonly requiresAcceptance = false,
private readonly validation: CommandValidation,
private readonly logger = getLogger('chatStream')
) {
super()
Expand All @@ -37,7 +38,7 @@ export class ChatStream extends Writable {
this.tabID,
this.triggerID,
this.toolUse,
this.requiresAcceptance
this.validation
)
callback()
}
Expand All @@ -49,7 +50,7 @@ export class ChatStream extends Writable {
this.tabID,
this.triggerID,
this.toolUse,
this.requiresAcceptance
this.validation
)
}
callback()
Expand Down
216 changes: 167 additions & 49 deletions packages/core/src/codewhispererChat/tools/executeBash.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,130 @@

import { Writable } from 'stream'
import { getLogger } from '../../shared/logger/logger'
import { fs } from '../../shared/fs/fs' // e.g. for getUserHomeDir()
import { fs } from '../../shared/fs/fs'
import { ChildProcess, ChildProcessOptions } from '../../shared/utilities/processUtils'
import { InvokeOutput, OutputKind, sanitizePath } from './toolShared'
import { split } from 'shlex'

export const readOnlyCommands: string[] = ['ls', 'cat', 'echo', 'pwd', 'which', 'head', 'tail']
export enum CommandCategory {
ReadOnly,
HighRisk,
Destructive,
}

export const dangerousPatterns = new Set(['<(', '$(', '`', '>', '&&', '||'])
export const commandCategories = new Map<string, CommandCategory>([
// ReadOnly commands
['ls', CommandCategory.ReadOnly],
['cat', CommandCategory.ReadOnly],
['bat', CommandCategory.ReadOnly],
['pwd', CommandCategory.ReadOnly],
['echo', CommandCategory.ReadOnly],
['file', CommandCategory.ReadOnly],
['less', CommandCategory.ReadOnly],
['more', CommandCategory.ReadOnly],
['tree', CommandCategory.ReadOnly],
['find', CommandCategory.ReadOnly],
['top', CommandCategory.ReadOnly],
['htop', CommandCategory.ReadOnly],
['ps', CommandCategory.ReadOnly],
['df', CommandCategory.ReadOnly],
['du', CommandCategory.ReadOnly],
['free', CommandCategory.ReadOnly],
['uname', CommandCategory.ReadOnly],
['date', CommandCategory.ReadOnly],
['whoami', CommandCategory.ReadOnly],
['which', CommandCategory.ReadOnly],
['ping', CommandCategory.ReadOnly],
['ifconfig', CommandCategory.ReadOnly],
['ip', CommandCategory.ReadOnly],
['netstat', CommandCategory.ReadOnly],
['ss', CommandCategory.ReadOnly],
['dig', CommandCategory.ReadOnly],
['grep', CommandCategory.ReadOnly],
['wc', CommandCategory.ReadOnly],
['sort', CommandCategory.ReadOnly],
['diff', CommandCategory.ReadOnly],
['head', CommandCategory.ReadOnly],
['tail', CommandCategory.ReadOnly],

// HighRisk commands
['chmod', CommandCategory.HighRisk],
['chown', CommandCategory.HighRisk],
['mv', CommandCategory.HighRisk],
['cp', CommandCategory.HighRisk],
['ln', CommandCategory.HighRisk],
['mount', CommandCategory.HighRisk],
['umount', CommandCategory.HighRisk],
['kill', CommandCategory.HighRisk],
['killall', CommandCategory.HighRisk],
['pkill', CommandCategory.HighRisk],
['iptables', CommandCategory.HighRisk],
['route', CommandCategory.HighRisk],
['systemctl', CommandCategory.HighRisk],
['service', CommandCategory.HighRisk],
['crontab', CommandCategory.HighRisk],
['at', CommandCategory.HighRisk],
['tar', CommandCategory.HighRisk],
['awk', CommandCategory.HighRisk],
['sed', CommandCategory.HighRisk],
['wget', CommandCategory.HighRisk],
['curl', CommandCategory.HighRisk],
['nc', CommandCategory.HighRisk],
['ssh', CommandCategory.HighRisk],
['scp', CommandCategory.HighRisk],
['ftp', CommandCategory.HighRisk],
['sftp', CommandCategory.HighRisk],
['rsync', CommandCategory.HighRisk],
['chroot', CommandCategory.HighRisk],
['lsof', CommandCategory.HighRisk],
['strace', CommandCategory.HighRisk],
['gdb', CommandCategory.HighRisk],

// Destructive commands
['rm', CommandCategory.Destructive],
['dd', CommandCategory.Destructive],
['mkfs', CommandCategory.Destructive],
['fdisk', CommandCategory.Destructive],
['shutdown', CommandCategory.Destructive],
['reboot', CommandCategory.Destructive],
['poweroff', CommandCategory.Destructive],
['sudo', CommandCategory.Destructive],
['su', CommandCategory.Destructive],
['useradd', CommandCategory.Destructive],
['userdel', CommandCategory.Destructive],
['passwd', CommandCategory.Destructive],
['visudo', CommandCategory.Destructive],
['insmod', CommandCategory.Destructive],
['rmmod', CommandCategory.Destructive],
['modprobe', CommandCategory.Destructive],
['apt', CommandCategory.Destructive],
['yum', CommandCategory.Destructive],
['dnf', CommandCategory.Destructive],
['pacman', CommandCategory.Destructive],
['perl', CommandCategory.Destructive],
['python', CommandCategory.Destructive],
['bash', CommandCategory.Destructive],
['sh', CommandCategory.Destructive],
['exec', CommandCategory.Destructive],
['eval', CommandCategory.Destructive],
['xargs', CommandCategory.Destructive],
])
export const maxBashToolResponseSize: number = 1024 * 1024 // 1MB
export const lineCount: number = 1024
export const dangerousPatterns: string[] = ['|', '<(', '$(', '`', '>', '&&', '||']
export const destructiveCommandWarningMessage = '⚠️ WARNING: Destructive command detected:\n\n'
export const highRiskCommandWarningMessage = '⚠️ WARNING: High risk command detected:\n\n'

export interface ExecuteBashParams {
command: string
cwd?: string
}

export interface CommandValidation {
requiresAcceptance: boolean
warning?: string
}

export class ExecuteBash {
private readonly command: string
private readonly workingDirectory?: string
Expand All @@ -34,7 +144,7 @@ export class ExecuteBash {
throw new Error('Bash command cannot be empty.')
}

const args = ExecuteBash.parseCommand(this.command)
const args = split(this.command)
if (!args || args.length === 0) {
throw new Error('No command found.')
}
Expand All @@ -46,22 +156,67 @@ export class ExecuteBash {
}
}

public requiresAcceptance(): boolean {
public requiresAcceptance(): CommandValidation {
try {
const args = ExecuteBash.parseCommand(this.command)
const args = split(this.command)
if (!args || args.length === 0) {
return true
return { requiresAcceptance: true }
}

if (args.some((arg) => dangerousPatterns.some((pattern) => arg.includes(pattern)))) {
return true
// Split commands by pipe and process each segment
let currentCmd: string[] = []
const allCommands: string[][] = []

for (const arg of args) {
if (arg === '|') {
if (currentCmd.length > 0) {
allCommands.push(currentCmd)
}
currentCmd = []
} else if (arg.includes('|')) {
return { requiresAcceptance: true }
} else {
currentCmd.push(arg)
}
}

if (currentCmd.length > 0) {
allCommands.push(currentCmd)
}

const command = args[0]
return !readOnlyCommands.includes(command)
for (const cmdArgs of allCommands) {
if (cmdArgs.length === 0) {
return { requiresAcceptance: true }
}

const command = cmdArgs[0]
const category = commandCategories.get(command)

switch (category) {
case CommandCategory.Destructive:
return { requiresAcceptance: true, warning: destructiveCommandWarningMessage }
case CommandCategory.HighRisk:
return {
requiresAcceptance: true,
warning: highRiskCommandWarningMessage,
}
case CommandCategory.ReadOnly:
if (
cmdArgs.some((arg) =>
Array.from(dangerousPatterns).some((pattern) => arg.includes(pattern))
)
) {
return { requiresAcceptance: true, warning: highRiskCommandWarningMessage }
}
return { requiresAcceptance: false }
default:
return { requiresAcceptance: true, warning: highRiskCommandWarningMessage }
}
}
return { requiresAcceptance: true }
} catch (error) {
this.logger.warn(`Error while checking acceptance: ${(error as Error).message}`)
return true
return { requiresAcceptance: true }
}
}

Expand Down Expand Up @@ -167,43 +322,6 @@ export class ExecuteBash {
return output
}

private static parseCommand(command: string): string[] | undefined {
const result: string[] = []
let current = ''
let inQuote: string | undefined
let escaped = false

for (const char of command) {
if (escaped) {
current += char
escaped = false
} else if (char === '\\') {
escaped = true
} else if (inQuote) {
if (char === inQuote) {
inQuote = undefined
} else {
current += char
}
} else if (char === '"' || char === "'") {
inQuote = char
} else if (char === ' ' || char === '\t') {
if (current) {
result.push(current)
current = ''
}
} else {
current += char
}
}

if (current) {
result.push(current)
}

return result
}

public queueDescription(updates: Writable): void {
updates.write(`I will run the following shell command:\n`)
updates.write('```bash\n' + this.command + '\n```')
Expand Down
Loading
Loading