Skip to content

Commit d80301e

Browse files
committed
feat(tool call approval)
Implement tool call approval. Default implementation uses floating windows. A nice enhancement would be to look for enhanced vim.ui.select implementations and use those (e.g. snacks)
1 parent 8e315fd commit d80301e

File tree

6 files changed

+291
-13
lines changed

6 files changed

+291
-13
lines changed

lua/eca/approve.lua

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
local M = {}
2+
3+
---@param tool_call eca.ToolCallRun
4+
function M.get_preview_lines(tool_call)
5+
if not tool_call.details then
6+
local arguments = vim.split(vim.inspect(tool_call.arguments), "\n")
7+
local messages = {}
8+
if tool_call.summary then
9+
table.insert(messages, "Summary: " .. tool_call.summary)
10+
end
11+
table.insert(messages, "Tool Name: " .. tool_call.name)
12+
table.insert(messages, "Tool Type: " .. tool_call.origin)
13+
table.insert(messages, "Tool Arguments: ")
14+
for _, v in pairs(arguments) do
15+
table.insert(messages, v)
16+
end
17+
return messages
18+
end
19+
local lines = vim.split(tool_call.details.diff, "\n")
20+
return { tool_call.details.path, unpack(lines) }
21+
end
22+
23+
---@param lines string[]
24+
---@return {row: number, col: number, width: number, height: number}
25+
local function get_position(lines)
26+
local gheight = math.floor(
27+
vim.api.nvim_list_uis() and vim.api.nvim_list_uis()[1] and vim.api.nvim_list_uis()[1].height or vim.o.lines
28+
)
29+
local gwidth = math.floor(
30+
vim.api.nvim_list_uis() and vim.api.nvim_list_uis()[1] and vim.api.nvim_list_uis()[1].width or vim.o.columns
31+
)
32+
local height = #lines > 10 and 35 or #lines
33+
local width = 0
34+
for _, line in ipairs(lines) do
35+
if #line > width then
36+
width = #line
37+
end
38+
end
39+
return {
40+
row = (gheight - height) * 0.5,
41+
col = (gwidth - width) * 0.5,
42+
width = math.floor(width * 1.5),
43+
height = height,
44+
}
45+
end
46+
47+
---@param tool_call eca.ToolCallRun
48+
---@param on_accept function
49+
---@param on_deny function
50+
function M.display_preview_lines(tool_call, on_accept, on_deny)
51+
local lines = M.get_preview_lines(tool_call)
52+
local buf = vim.api.nvim_create_buf(false, false)
53+
vim.api.nvim_buf_set_lines(buf, 0, -1, false, lines)
54+
vim.api.nvim_set_option_value("modifiable", false, { buf = buf })
55+
local position = get_position(lines)
56+
local title = tool_call.summary or tool_call.name
57+
local win = vim.api.nvim_open_win(buf, true, {
58+
border = "single",
59+
title = "Approve Tool Call(y/n): " .. title,
60+
relative = "editor",
61+
row = position.row,
62+
col = position.col,
63+
width = position.width,
64+
height = position.height,
65+
})
66+
if tool_call.details then
67+
vim.api.nvim_set_option_value("filetype", "diff", { buf = buf })
68+
else
69+
vim.api.nvim_set_option_value("number", false, { win = win })
70+
vim.api.nvim_set_option_value("relativenumber", false, { win = win })
71+
end
72+
73+
vim.keymap.set({ "n", "i" }, "y", "", {
74+
buffer = buf,
75+
callback = function()
76+
vim.api.nvim_win_close(win, true)
77+
vim.api.nvim_buf_delete(buf, { force = true })
78+
if on_accept then
79+
on_accept()
80+
end
81+
end,
82+
})
83+
vim.keymap.set({ "n", "i" }, "n", "", {
84+
buffer = buf,
85+
callback = function()
86+
vim.api.nvim_win_close(win, true)
87+
vim.api.nvim_buf_delete(buf, { force = true })
88+
if on_deny then
89+
on_deny()
90+
end
91+
end,
92+
})
93+
end
94+
95+
---@param tool_call eca.ToolCallRun
96+
---@param on_accept function
97+
---@param on_deny function
98+
function M.approve_tool_call(tool_call, on_accept, on_deny)
99+
M.display_preview_lines(tool_call, on_accept, on_deny)
100+
end
101+
return M

lua/eca/mediator.lua

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ end
1212

1313
---@param method string
1414
---@param params eca.MessageParams
15-
---@param callback fun(err: string, result: table)
15+
---@param callback? fun(err?: string, result?: table)
1616
function mediator:send(method, params, callback)
1717
if not self.server:is_running() then
18-
callback("Server is not running, please start the server", nil)
19-
return
18+
if callback then
19+
callback("Server is not running, please start the server", nil)
20+
end
21+
require("eca.logger").notify("Server is not rnning, please start the server", vim.log.levels.WARN)
2022
end
2123
self.server:send_request(method, params, callback)
2224
end

lua/eca/sidebar.lua

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ function M:_set_welcome_content()
880880
"- **RepoMap**: Use `:EcaAddRepoMap` to add repository structure context",
881881
"",
882882
"---",
883-
""
883+
"",
884884
}
885885

886886
Logger.debug("Setting welcome content for new chat")
@@ -1204,6 +1204,7 @@ function M:handle_chat_content_received(params)
12041204
end
12051205

12061206
local content = params.content
1207+
local chat_id = params.chatId
12071208

12081209
if content.type == "text" then
12091210
-- Handle streaming text content
@@ -1234,6 +1235,8 @@ function M:handle_chat_content_received(params)
12341235
self:_handle_tool_call_prepare(content)
12351236
-- IMPORTANT: Return immediately - do NOT display anything for toolCallPrepare
12361237
return
1238+
elseif content.type == "toolCallRun" then
1239+
self:render_tool_call(content, chat_id)
12371240
elseif content.type == "toolCallRunning" then
12381241
-- Show the accumulated tool call
12391242
self:_display_tool_call(content)
@@ -1274,6 +1277,16 @@ function M:handle_chat_content_received(params)
12741277
end
12751278
end
12761279

1280+
function M:render_tool_call(tool_content, chat_id)
1281+
if tool_content.type == "toolCallRun" and tool_content.manualApproval then
1282+
return require("eca.approve").approve_tool_call(tool_content, function()
1283+
self.mediator:send("chat/toolCallApprove", { chatId = chat_id, toolCallId = tool_content.id }, nil)
1284+
end, function()
1285+
self.mediator:send("chat/toolCallReject", { chatId = chat_id, toolCallId = tool_content.id }, nil)
1286+
end)
1287+
end
1288+
end
1289+
12771290
---@param text string
12781291
function M:_handle_streaming_text(text)
12791292
-- Only check for empty text

lua/eca/types.lua

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
---@meta
22

3+
---@alias eca.ChatModel string
4+
---@alias eca.ChatBehavior 'agent'|'plan'
5+
---@class eca.ServerCapabilities
6+
---@field welcomeMessage string
7+
---@field models eca.ChatModel[]
8+
---@field defaultModel eca.ChatModel
9+
---@field behaviors eca.ChatBehavior[]
10+
---@field defaultBehavior eca.ChatBehavior
11+
312
---@class eca.ChatContext
413
---@field type string
514
---@field path? string
@@ -9,7 +18,7 @@
918
---@field name? string
1019
---@field description? string
1120
---@field mimeType? string
12-
---@field server string
21+
---@field server? string
1322

1423
---@class eca.ChatCommand
1524
---@field name string
@@ -33,24 +42,30 @@
3342
---@field linesAdded integer the count of lines added in this change
3443
---@field linesRemoved integer the count of lines removed in this change
3544

36-
---@class eca.ToolCallRun
37-
---@field type 'toolCallRun'
45+
---@class eca.ToolCallPrepare
46+
---@field type 'toolCallPrepare'
3847
---@field origin eca.ToolCallOrigin
3948
---@field id string the id of the tool call
4049
---@field name string name of the tool
41-
---@field arguments {[string]: string} arguments of the tool call
50+
---@field argumentsText {[string]: string} arguments of the tool call
4251
---@field manualApproval boolean whether the call requires manual approval from the user
4352
---@field summary string summary text to present about this tool call
44-
---@field details eca.ToolCallDetails extra details for the call. clients may use this to present a different UX for this tool call.
53+
--- extra details for the call. clients may use this to present a different UX
54+
--- for this tool call.
55+
---@field details eca.ToolCallDetails
56+
---
4557

46-
---@class eca.ToolCallRunning
47-
---@field type 'toolCallRunning'
58+
---@class eca.ToolCallRun
59+
---@field type 'toolCallRun'
4860
---@field origin eca.ToolCallOrigin
4961
---@field id string the id of the tool call
5062
---@field name string name of the tool
5163
---@field arguments {[string]: string} arguments of the tool call
52-
---@field summary? string summary text to present about this tool call
53-
---@field details? eca.ToolCallDetails extra details for the call. clients may use this to present a different UX for this tool call.
64+
---@field manualApproval boolean whether the call requires manual approval from the user
65+
---@field summary string summary text to present about this tool call
66+
--- extra details for the call. clients may use this to present a different UX
67+
--- for this tool call.
68+
---@field details eca.ToolCallDetails
5469

5570
---@class eca.ToolCalled
5671
---@field type 'toolCalled'
@@ -61,3 +76,11 @@
6176
---@field outputs {type: 'text', text: string}[] the result of the tool call
6277
---@field summary? string summary text to present about the tool call
6378
---@field details? eca.ToolCallDetails extra details about the call
79+
80+
---@class eca.UsageContent
81+
---@field type 'usage'
82+
---@field messageInputTokens number
83+
---@field messageOutputTokens number
84+
---@field sessionTokens number
85+
---@field messageCost? string
86+
---@field sessionCost? string

tests/stubs/tool_calls.lua

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
local stubs = {}
2+
3+
stubs.read_file = {
4+
arguments = {
5+
path = "/Users/tgeorge/git/eca-nvim/hack/messages.lua",
6+
},
7+
id = "toolu_013zj73SHzZNoeE7kzD7qzb4",
8+
manualApproval = true,
9+
name = "eca_read_file",
10+
origin = "native",
11+
summary = "Reading file messages.lua",
12+
type = "toolCallRun",
13+
}
14+
15+
stubs.edit_file = {
16+
arguments = {
17+
new_content = 'local M = {}\n\n--- Show ECA messages using snacks.picker\nfunction M.show()\n local has_snacks, picker = pcall(require, "snacks.picker")\n if not has_snacks then\n vim.notify("snacks.picker is not available", vim.log.levels.ERROR)\n return\n end\n\n Snacks.picker(\n ---@type snacks.picker.Config\n {\n source = "eca messages",\n finder = function(opts, ctx)\n ---@type snacks.picker.finder.Item[]\n local items = {}\n for msg in vim.iter(require("eca").server.messages) do\n local decoded = vim.json.decode(msg.content)\n table.insert(items, {\n text = decoded.method,\n idx = decoded.id,\n preview = {\n text = vim.inspect(decoded),\n ft = "lua",\n},\n})\n end\n return items\n end,\n preview = "preview",\n format = "text",\n confirm = function(self, item, _)\n vim.fn.setreg("", item.preview.text)\n self:close()\n end,\n }\n )\nend\n\nreturn M',
18+
original_content = 'local has_snacks, picker = pcall(require, "snacks.picker")\nif has_snacks then\n Snacks.picker(\n ---@type snacks.picker.Config\n {\n source = "eca messages",\n finder = function(opts, ctx)\n ---@type snacks.picker.finder.Item[]\n local items = {}\n for msg in vim.iter(require("eca").server.messages) do\n local decoded = vim.json.decode(msg.content)\n table.insert(items, {\n text = decoded.method,\n idx = decoded.id,\n preview = {\n text = vim.inspect(decoded),\n ft = "lua",\n },\n })\n end\n return items\n end,\n preview = "preview",\n format = "text",\n confirm = function(self, item, _)\n vim.fn.setreg("", item.preview.text)\n self:close()\n end,\n }\n )\nend',
19+
path = "/Users/tgeorge/git/eca-nvim/hack/messages.lua",
20+
},
21+
details = {
22+
diff = '@@ -1, 5 +1, 13 @@\n-local has_snacks, picker = pcall(require, "snacks.picker")\n-if has_snacks then\n+local M = {}\n+\n+--- Show ECA messages using snacks.picker\n+function M.show()\n+ local has_snacks, picker = pcall(require, "snacks.picker")\n+ if not has_snacks then\n+ vim.notify("snacks.picker is not available", vim.log.levels.ERROR)\n+ return\n+ end\n+\n Snacks.picker(\n ---@type snacks.picker.Config\n {\n@@ -29, 3 +37, 5 @@\n }\n )\n end\n+\n+return M',
23+
linesAdded = 12,
24+
linesRemoved = 10,
25+
path = "/Users/tgeorge/git/eca-nvim/hack/messages.lua",
26+
type = "fileChange",
27+
},
28+
id = "toolu_01KAVb3qpJDcSnbnJmpUndQF",
29+
manualApproval = true,
30+
name = "eca_edit_file",
31+
origin = "native",
32+
summary = "Editting file",
33+
type = "toolCallRun",
34+
}
35+
36+
stubs.mcp = {
37+
arguments = {
38+
content = 'return "hello world"',
39+
path = "/Users/tgeorge/git/eca-nvim/hack/test_mcp_write_file.lua",
40+
},
41+
id = "toolu_01B8xcb7csLRHvqrnAZTgzPi",
42+
manualApproval = true,
43+
name = "write_file",
44+
origin = "mcp",
45+
type = "toolCallRun",
46+
}
47+
48+
return stubs

tests/test_approve.lua

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
local MiniTest = require("mini.test")
2+
local eq = MiniTest.expect.equality
3+
local child = MiniTest.new_child_neovim()
4+
local stubs = require("tests.stubs.tool_calls")
5+
6+
local T = MiniTest.new_set({
7+
hooks = {
8+
pre_case = function()
9+
child.restart({ "-u", "scripts/minimal_init.lua" })
10+
child.lua([[
11+
_G.notifications = {}
12+
_G.on_accept = function() table.insert(_G.notifications, "accept") end
13+
_G.on_reject = function() table.insert(_G.notifications, "reject") end
14+
]])
15+
end,
16+
post_once = child.stop,
17+
},
18+
})
19+
20+
T["preview lines"] = function()
21+
local test_cases = {
22+
{
23+
input = stubs.read_file,
24+
want = {
25+
"Summary: Reading file messages.lua",
26+
"Tool Name: eca_read_file",
27+
"Tool Type: native",
28+
"Tool Arguments: ",
29+
"{",
30+
' path = "/Users/tgeorge/git/eca-nvim/hack/messages.lua"',
31+
"}",
32+
},
33+
},
34+
{
35+
input = stubs.edit_file,
36+
want = {
37+
"/Users/tgeorge/git/eca-nvim/hack/messages.lua",
38+
"@@ -1, 5 +1, 13 @@",
39+
'-local has_snacks, picker = pcall(require, "snacks.picker")',
40+
"-if has_snacks then",
41+
"+local M = {}",
42+
"+",
43+
"+--- Show ECA messages using snacks.picker",
44+
"+function M.show()",
45+
'+ local has_snacks, picker = pcall(require, "snacks.picker")',
46+
"+ if not has_snacks then",
47+
'+ vim.notify("snacks.picker is not available", vim.log.levels.ERROR)',
48+
"+ return",
49+
"+ end",
50+
"+",
51+
" Snacks.picker(",
52+
" ---@type snacks.picker.Config",
53+
" {",
54+
"@@ -29, 3 +37, 5 @@",
55+
" }",
56+
" )",
57+
" end",
58+
"+",
59+
"+return M",
60+
},
61+
},
62+
{
63+
input = stubs.mcp,
64+
want = {
65+
"Tool Name: write_file",
66+
"Tool Type: mcp",
67+
"Tool Arguments: ",
68+
"{",
69+
" content = 'return \"hello world\"',",
70+
' path = "/Users/tgeorge/git/eca-nvim/hack/test_mcp_write_file.lua"',
71+
"}",
72+
},
73+
},
74+
}
75+
for _, test_case in pairs(test_cases) do
76+
local got = child.lua_get('require("eca.approve").get_preview_lines(...)', { test_case.input })
77+
eq(got, test_case.want)
78+
end
79+
end
80+
81+
T["tool approval calls callback"] = function()
82+
child.lua("_G.tool_call = " .. vim.inspect(stubs.read_file))
83+
child.lua('require("eca.approve").approve_tool_call(_G.tool_call, _G.on_accept, _G.on_reject)')
84+
child.type_keys("y")
85+
eq(child.lua_get("_G.notifications"), { "accept" })
86+
child.lua('require("eca.approve").approve_tool_call(_G.tool_call, _G.on_accept, _G.on_reject)')
87+
child.type_keys("n")
88+
eq(child.lua_get("_G.notifications"), { "accept", "reject" })
89+
end
90+
91+
return T

0 commit comments

Comments
 (0)