|
| 1 | +/*! |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | +import { Writable } from 'stream' |
| 6 | +import { FsRead, FsReadParams } from './fsRead' |
| 7 | +import { FsWrite, FsWriteParams } from './fsWrite' |
| 8 | +import { ExecuteBash, ExecuteBashParams } from './executeBash' |
| 9 | +import { ToolResult, ToolResultContentBlock, ToolResultStatus, ToolUse } from '@amzn/codewhisperer-streaming' |
| 10 | +import { InvokeOutput } from './toolShared' |
| 11 | + |
| 12 | +export enum ToolType { |
| 13 | + FsRead = 'fsRead', |
| 14 | + FsWrite = 'fsWrite', |
| 15 | + ExecuteBash = 'executeBash', |
| 16 | +} |
| 17 | + |
| 18 | +export type Tool = |
| 19 | + | { type: ToolType.FsRead; tool: FsRead } |
| 20 | + | { type: ToolType.FsWrite; tool: FsWrite } |
| 21 | + | { type: ToolType.ExecuteBash; tool: ExecuteBash } |
| 22 | + |
| 23 | +export class ToolUtils { |
| 24 | + static displayName(tool: Tool): string { |
| 25 | + switch (tool.type) { |
| 26 | + case ToolType.FsRead: |
| 27 | + return 'Read from filesystem' |
| 28 | + case ToolType.FsWrite: |
| 29 | + return 'Write to filesystem' |
| 30 | + case ToolType.ExecuteBash: |
| 31 | + return 'Execute shell command' |
| 32 | + } |
| 33 | + } |
| 34 | + |
| 35 | + static requiresAcceptance(tool: Tool) { |
| 36 | + switch (tool.type) { |
| 37 | + case ToolType.FsRead: |
| 38 | + return false |
| 39 | + case ToolType.FsWrite: |
| 40 | + return true |
| 41 | + case ToolType.ExecuteBash: |
| 42 | + return tool.tool.requiresAcceptance() |
| 43 | + } |
| 44 | + } |
| 45 | + |
| 46 | + static async invoke(tool: Tool, updates: Writable): Promise<InvokeOutput> { |
| 47 | + switch (tool.type) { |
| 48 | + case ToolType.FsRead: |
| 49 | + return tool.tool.invoke(updates) |
| 50 | + case ToolType.FsWrite: |
| 51 | + return tool.tool.invoke(updates) |
| 52 | + case ToolType.ExecuteBash: |
| 53 | + return tool.tool.invoke(updates) |
| 54 | + } |
| 55 | + } |
| 56 | + |
| 57 | + static queueDescription(tool: Tool, updates: Writable): void { |
| 58 | + switch (tool.type) { |
| 59 | + case ToolType.FsRead: |
| 60 | + tool.tool.queueDescription(updates) |
| 61 | + break |
| 62 | + case ToolType.FsWrite: |
| 63 | + tool.tool.queueDescription(updates) |
| 64 | + break |
| 65 | + case ToolType.ExecuteBash: |
| 66 | + tool.tool.queueDescription(updates) |
| 67 | + break |
| 68 | + } |
| 69 | + } |
| 70 | + |
| 71 | + static async validate(tool: Tool): Promise<void> { |
| 72 | + switch (tool.type) { |
| 73 | + case ToolType.FsRead: |
| 74 | + return tool.tool.validate() |
| 75 | + case ToolType.FsWrite: |
| 76 | + return tool.tool.validate() |
| 77 | + case ToolType.ExecuteBash: |
| 78 | + return tool.tool.validate() |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + static tryFromToolUse(value: ToolUse): Tool | ToolResult { |
| 83 | + const mapErr = (parseError: any): ToolResult => ({ |
| 84 | + toolUseId: value.toolUseId, |
| 85 | + content: [ |
| 86 | + { |
| 87 | + type: 'text', |
| 88 | + text: `Failed to validate tool parameters: ${parseError}. The model has either suggested tool parameters which are incompatible with the existing tools, or has suggested one or more tool that does not exist in the list of known tools.`, |
| 89 | + } as ToolResultContentBlock, |
| 90 | + ], |
| 91 | + status: ToolResultStatus.ERROR, |
| 92 | + }) |
| 93 | + |
| 94 | + try { |
| 95 | + switch (value.name) { |
| 96 | + case ToolType.FsRead: |
| 97 | + return { |
| 98 | + type: ToolType.FsRead, |
| 99 | + tool: new FsRead(value.input as unknown as FsReadParams), |
| 100 | + } |
| 101 | + case ToolType.FsWrite: |
| 102 | + return { |
| 103 | + type: ToolType.FsWrite, |
| 104 | + tool: new FsWrite(value.input as unknown as FsWriteParams), |
| 105 | + } |
| 106 | + case ToolType.ExecuteBash: |
| 107 | + return { |
| 108 | + type: ToolType.ExecuteBash, |
| 109 | + tool: new ExecuteBash(value.input as unknown as ExecuteBashParams), |
| 110 | + } |
| 111 | + default: |
| 112 | + return { |
| 113 | + toolUseId: value.toolUseId, |
| 114 | + content: [ |
| 115 | + { |
| 116 | + type: 'text', |
| 117 | + text: `The tool, "${value.name}" is not supported by the client`, |
| 118 | + } as ToolResultContentBlock, |
| 119 | + ], |
| 120 | + } |
| 121 | + } |
| 122 | + } catch (error) { |
| 123 | + return mapErr(error) |
| 124 | + } |
| 125 | + } |
| 126 | +} |
0 commit comments