Skip to content

Commit cab3193

Browse files
authored
Feat: Add option to require approval before tool execution (#140)
Adds a new CLI option, `--approve-tool-run` (or via config setting), that when enabled, prompts the user to approve a tool's execution before it runs. This option is disabled by default to maintain existing behavior.
1 parent d3cae7c commit cab3193

File tree

6 files changed

+213
-14
lines changed

6 files changed

+213
-14
lines changed

cmd/root.go

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ var (
3838
streamFlag bool // Enable streaming output
3939
compactMode bool // Enable compact output mode
4040
scriptMCPConfig *config.Config // Used to override config in script mode
41+
approveToolRun bool
4142

4243
// Session management
4344
saveSessionPath string
@@ -302,6 +303,8 @@ func init() {
302303
BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling")
303304
rootCmd.PersistentFlags().
304305
BoolVar(&noHooks, "no-hooks", false, "disable all hooks execution")
306+
rootCmd.PersistentFlags().
307+
BoolVar(&approveToolRun, "approve-tool-run", false, "enable requiring user approval for every tool call")
305308

306309
// Session management flags
307310
rootCmd.PersistentFlags().
@@ -347,6 +350,7 @@ func init() {
347350
viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
348351
viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu"))
349352
viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify"))
353+
viper.BindPFlag("approve-tool-run", rootCmd.PersistentFlags().Lookup("approve-tool-run"))
350354

351355
// Defaults are already set in flag definitions, no need to duplicate in viper
352356

@@ -445,7 +449,8 @@ func runNormalMode(ctx context.Context) error {
445449
debugLogger = bufferedLogger
446450
}
447451

448-
mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ModelConfig: modelConfig,
452+
mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
453+
ModelConfig: modelConfig,
449454
MCPConfig: mcpConfig,
450455
SystemPrompt: systemPrompt,
451456
MaxSteps: viper.GetInt("max-steps"),
@@ -743,7 +748,8 @@ func runNormalMode(ctx context.Context) error {
743748
return fmt.Errorf("--quiet flag can only be used with --prompt/-p")
744749
}
745750

746-
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor)
751+
approveToolRun := viper.GetBool("approve-tool-run")
752+
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor, approveToolRun)
747753
}
748754

749755
// AgenticLoopConfig configures the behavior of the unified agentic loop.
@@ -754,6 +760,7 @@ type AgenticLoopConfig struct {
754760
IsInteractive bool // true for interactive mode, false for non-interactive
755761
InitialPrompt string // initial prompt for non-interactive mode
756762
ContinueAfterRun bool // true to continue to interactive mode after initial run (--no-exit)
763+
ApproveToolRun bool // only used in interactive mode
757764

758765
// UI configuration
759766
Quiet bool // suppress all output except final response
@@ -1103,7 +1110,27 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
11031110
currentSpinner.Start()
11041111
}
11051112
},
1106-
streamingCallback, // Add streaming callback as the last parameter
1113+
// Add streaming callback handler
1114+
streamingCallback,
1115+
// Tool call approval handler - called before tool execution to get user approval
1116+
func(toolName, toolArgs string) (bool, error) {
1117+
if !config.IsInteractive || !config.ApproveToolRun {
1118+
return true, nil
1119+
}
1120+
if currentSpinner != nil {
1121+
currentSpinner.Stop()
1122+
currentSpinner = nil
1123+
}
1124+
allow, err := cli.GetToolApproval(toolName, toolArgs)
1125+
if err != nil {
1126+
return false, err
1127+
}
1128+
// Start spinner again for tool calls
1129+
currentSpinner = ui.NewSpinner("Thinking...")
1130+
currentSpinner.Start()
1131+
1132+
return allow, nil
1133+
},
11071134
)
11081135

11091136
// Make sure spinner is stopped if still running
@@ -1306,6 +1333,7 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
13061333
IsInteractive: false,
13071334
InitialPrompt: prompt,
13081335
ContinueAfterRun: noExit,
1336+
ApproveToolRun: false,
13091337
Quiet: quiet,
13101338
ServerNames: serverNames,
13111339
ToolNames: toolNames,
@@ -1318,12 +1346,13 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
13181346
}
13191347

13201348
// runInteractiveMode handles the interactive mode execution
1321-
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor) error {
1349+
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor, approveToolRun bool) error {
13221350
// Configure and run unified agentic loop
13231351
config := AgenticLoopConfig{
13241352
IsInteractive: true,
13251353
InitialPrompt: "",
13261354
ContinueAfterRun: false,
1355+
ApproveToolRun: approveToolRun,
13271356
Quiet: false,
13281357
ServerNames: serverNames,
13291358
ToolNames: toolNames,

internal/agent/agent.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"strings"
8+
"time"
9+
710
tea "github.com/charmbracelet/bubbletea"
811
"github.com/cloudwego/eino/components/model"
912
"github.com/cloudwego/eino/components/tool"
@@ -12,8 +15,6 @@ import (
1215
"github.com/mark3labs/mcphost/internal/config"
1316
"github.com/mark3labs/mcphost/internal/models"
1417
"github.com/mark3labs/mcphost/internal/tools"
15-
"strings"
16-
"time"
1718
)
1819

1920
// AgentConfig holds configuration options for creating a new Agent.
@@ -57,6 +58,10 @@ type StreamingResponseHandler func(content string)
5758
// It receives any text content that the model generates alongside tool calls.
5859
type ToolCallContentHandler func(content string)
5960

61+
// ToolApprovalHandler is a function type for handling user approval of tool calls.
62+
// It receives the tool name and arguments, and returns true if the user approves.
63+
type ToolApprovalHandler func(toolName, toolArgs string) (bool, error)
64+
6065
// Agent represents an AI agent with MCP tool integration and real-time tool call display.
6166
// It manages the interaction between an LLM and various tools through the MCP protocol.
6267
type Agent struct {
@@ -128,17 +133,17 @@ type GenerateWithLoopResult struct {
128133
// It handles the conversation flow, executing tools as needed and invoking callbacks for various events.
129134
// This method does not support streaming responses; use GenerateWithLoopAndStreaming for streaming support.
130135
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
131-
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) {
132-
133-
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil)
136+
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onToolApproval ToolApprovalHandler,
137+
) (*GenerateWithLoopResult, error) {
138+
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil, onToolApproval)
134139
}
135140

136141
// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks.
137142
// It handles the conversation flow, executing tools as needed and invoking callbacks for various events including streaming chunks.
138143
// The onStreamingResponse callback is invoked for each content chunk during streaming if streaming is enabled.
139144
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message,
140-
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) {
141-
145+
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler, onToolApproval ToolApprovalHandler,
146+
) (*GenerateWithLoopResult, error) {
142147
// Create a copy of messages to avoid modifying the original
143148
workingMessages := make([]*schema.Message, len(messages))
144149
copy(workingMessages, messages)
@@ -200,6 +205,19 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc
200205

201206
// Handle tool calls
202207
for _, toolCall := range response.ToolCalls {
208+
if onToolApproval != nil {
209+
approved, err := onToolApproval(toolCall.Function.Name, toolCall.Function.Arguments)
210+
if err != nil {
211+
return nil, err
212+
}
213+
if !approved {
214+
rejectedMsg := fmt.Sprintf("The user did not allow tool call %s. Reason: User cancelled.", toolCall.Function.Name)
215+
toolMessage := schema.ToolMessage(rejectedMsg, toolCall.ID)
216+
workingMessages = append(workingMessages, toolMessage)
217+
continue
218+
}
219+
}
220+
203221
// Notify about tool call
204222
if onToolCall != nil {
205223
onToolCall(toolCall.Function.Name, toolCall.Function.Arguments)

internal/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ type Config struct {
166166
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
167167
Theme any `json:"theme" yaml:"theme"`
168168
MarkdownTheme any `json:"markdown-theme" yaml:"markdown-theme"`
169+
ApproveToolRun bool `json:"approve-tool-run" yaml:"approve-tool-run"`
169170

170171
// Model generation parameters
171172
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`

internal/ui/cli.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ import (
1313
"golang.org/x/term"
1414
)
1515

16-
var (
17-
promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12"))
18-
)
16+
var promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12"))
1917

2018
// CLI manages the command-line interface for MCPHost, providing message rendering,
2119
// user input handling, and display management. It supports both standard and compact
@@ -377,6 +375,22 @@ func (c *CLI) IsSlashCommand(input string) bool {
377375
return strings.HasPrefix(input, "/")
378376
}
379377

378+
// GetToolApproval asks the user for permission to execute the tool with the given
379+
// arguments. Returns true if the user approves.
380+
func (c *CLI) GetToolApproval(toolName, toolArgs string) (bool, error) {
381+
input := NewToolApprovalInput(toolName, toolArgs, c.width)
382+
p := tea.NewProgram(input)
383+
finalModel, err := p.Run()
384+
if err != nil {
385+
return false, err
386+
}
387+
388+
if finalInput, ok := finalModel.(*ToolApprovalInput); ok {
389+
return finalInput.approved, nil
390+
}
391+
return false, fmt.Errorf("GetToolApproval: unexpected error type")
392+
}
393+
380394
// SlashCommandResult encapsulates the outcome of processing a slash command,
381395
// indicating whether the command was recognized and handled, and whether the
382396
// conversation history should be cleared as a result of the command.

internal/ui/tool_approval_input.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package ui
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
"github.com/charmbracelet/bubbles/textarea"
8+
tea "github.com/charmbracelet/bubbletea"
9+
"github.com/charmbracelet/lipgloss"
10+
)
11+
12+
type ToolApprovalInput struct {
13+
textarea textarea.Model
14+
toolName string
15+
toolArgs string
16+
width int
17+
selected bool // true when "yes" is highlighted and false when "no" is
18+
approved bool
19+
done bool
20+
}
21+
22+
func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput {
23+
ta := textarea.New()
24+
ta.Placeholder = ""
25+
ta.ShowLineNumbers = false
26+
ta.CharLimit = 1000
27+
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
28+
ta.SetHeight(4) // Default to 3 lines like huh
29+
ta.Focus()
30+
31+
// Style the textarea to match huh theme
32+
ta.FocusedStyle.Base = lipgloss.NewStyle()
33+
ta.FocusedStyle.Placeholder = lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
34+
ta.FocusedStyle.Text = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
35+
ta.FocusedStyle.Prompt = lipgloss.NewStyle()
36+
ta.FocusedStyle.CursorLine = lipgloss.NewStyle()
37+
ta.Cursor.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("39"))
38+
39+
return &ToolApprovalInput{
40+
textarea: ta,
41+
toolName: toolName,
42+
toolArgs: toolArgs,
43+
width: width,
44+
selected: true,
45+
}
46+
}
47+
48+
func (t *ToolApprovalInput) Init() tea.Cmd {
49+
return textarea.Blink
50+
}
51+
52+
func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
53+
switch msg := msg.(type) {
54+
case tea.KeyMsg:
55+
switch msg.String() {
56+
case "y", "Y":
57+
t.approved = true
58+
t.done = true
59+
return t, tea.Quit
60+
case "n", "N":
61+
t.approved = false
62+
t.done = true
63+
return t, tea.Quit
64+
case "left":
65+
t.selected = true
66+
return t, nil
67+
case "right":
68+
t.selected = false
69+
return t, nil
70+
case "enter":
71+
t.approved = t.selected
72+
t.done = true
73+
return t, tea.Quit
74+
case "esc", "ctrl+c":
75+
t.approved = false
76+
t.done = true
77+
return t, tea.Quit
78+
}
79+
}
80+
return t, nil
81+
}
82+
83+
func (t *ToolApprovalInput) View() string {
84+
if t.done {
85+
return "we are done"
86+
}
87+
// Add left padding to entire component (2 spaces like other UI elements)
88+
containerStyle := lipgloss.NewStyle().PaddingLeft(2)
89+
90+
// Title
91+
titleStyle := lipgloss.NewStyle().
92+
Foreground(lipgloss.Color("252")).
93+
MarginBottom(1)
94+
95+
// Input box with huh-like styling
96+
inputBoxStyle := lipgloss.NewStyle().
97+
Border(lipgloss.ThickBorder()).
98+
BorderLeft(true).
99+
BorderRight(false).
100+
BorderTop(false).
101+
BorderBottom(false).
102+
BorderForeground(lipgloss.Color("39")).
103+
PaddingLeft(1).
104+
Width(t.width - 2) // Account for container padding
105+
106+
// Style for the currently selected/highlighted option
107+
selectedStyle := lipgloss.NewStyle().
108+
Foreground(lipgloss.Color("42")). // Bright green
109+
Bold(true).
110+
Underline(true)
111+
112+
// Style for the unselected/unhighlighted option
113+
unselectedStyle := lipgloss.NewStyle().
114+
Foreground(lipgloss.Color("240")) // Dark gray
115+
116+
// Build the view
117+
var view strings.Builder
118+
view.WriteString(titleStyle.Render("Allow tool execution"))
119+
view.WriteString("\n")
120+
details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs)
121+
view.WriteString(details)
122+
view.WriteString("Allow tool execution: ")
123+
124+
var yesText, noText string
125+
if t.selected {
126+
yesText = selectedStyle.Render("[y]es")
127+
noText = unselectedStyle.Render("[n]o")
128+
} else {
129+
yesText = unselectedStyle.Render("[y]es")
130+
noText = selectedStyle.Render("[n]o")
131+
}
132+
view.WriteString(yesText + "/" + noText + "\n")
133+
134+
return containerStyle.Render(inputBoxStyle.Render(view.String()))
135+
}

sdk/mcphost.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) {
142142
nil, // onToolResult
143143
nil, // onResponse
144144
nil, // onToolCallContent
145+
nil, // onToolApproval
145146
)
146147
if err != nil {
147148
return "", err
@@ -181,6 +182,7 @@ func (m *MCPHost) PromptWithCallbacks(
181182
nil, // onResponse
182183
nil, // onToolCallContent
183184
onStreaming,
185+
nil, // onToolApproval
184186
)
185187
if err != nil {
186188
return "", err

0 commit comments

Comments
 (0)