Skip to content

Commit d9e9614

Browse files
committed
add refresh capability and enable toggle
1 parent 3b8eb34 commit d9e9614

File tree

7 files changed

+101
-45
lines changed

7 files changed

+101
-45
lines changed

lua/codecompanion/config.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,7 @@ If you are providing code changes, use the insert_edit_into_file tool (if availa
478478
},
479479
},
480480
mcp = {
481+
enabled = true,
481482
servers = {},
482483
},
483484
keymaps = {

lua/codecompanion/interactions/chat/init.lua

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,9 @@ function Chat.new(args)
526526

527527
self:update_metadata()
528528

529-
require("codecompanion.interactions.chat.mcp").start_servers()
529+
if config.interactions.chat.mcp.enabled then
530+
require("codecompanion.interactions.chat.mcp").start_servers()
531+
end
530532

531533
-- Likely this hasn't been set by the time the user opens the chat buffer
532534
if not _G.codecompanion_current_context then

lua/codecompanion/interactions/chat/mcp/client.lua

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,21 @@ StdioTransport.static.methods = {
6565
schedule_wrap = { default = vim.schedule_wrap },
6666
}
6767

68+
---@class CodeCompanion.MCP.StdioTransportArgs
69+
---@field name string
70+
---@field cfg CodeCompanion.MCP.ServerConfig
71+
---@field methods? table<string, function> Optional method overrides for testing
72+
6873
---Create a new StdioTransport for the given server configuration.
69-
---@param name string
70-
---@param cfg CodeCompanion.MCP.ServerConfig
71-
---@param methods? table<string, function> Optional method overrides for testing
74+
---@param args CodeCompanion.MCP.StdioTransportArgs
7275
---@return CodeCompanion.MCP.StdioTransport
73-
function StdioTransport:new(name, cfg, methods)
76+
function StdioTransport.new(args)
7477
return setmetatable({
75-
name = name,
76-
cmd = cfg.cmd,
77-
env = cfg.env,
78-
methods = transform_static_methods(StdioTransport, methods),
79-
}, self)
78+
name = args.name,
79+
cmd = args.cfg.cmd,
80+
env = args.cfg.env,
81+
methods = transform_static_methods(StdioTransport, args.methods),
82+
}, StdioTransport)
8083
end
8184

8285
---Start the underlying process and attach stdout/stderr callbacks.
@@ -245,8 +248,8 @@ Client.__index = Client
245248
Client.static = {}
246249
Client.static.methods = {
247250
new_transport = {
248-
default = function(name, cfg, methods)
249-
return StdioTransport:new(name, cfg, methods)
251+
default = function(args)
252+
return StdioTransport.new(args)
250253
end,
251254
},
252255
json_decode = { default = vim.json.decode },
@@ -255,25 +258,36 @@ Client.static.methods = {
255258
defer_fn = { default = vim.defer_fn },
256259
}
257260

261+
---@class CodeCompanion.MCP.ClientArgs
262+
---@field name string
263+
---@field cfg CodeCompanion.MCP.ServerConfig
264+
---@field methods? table<string, function> Optional method overrides for testing
265+
258266
---Create a new MCP client instance bound to the provided server configuration.
259-
---@param name string
260-
---@param cfg CodeCompanion.MCP.ServerConfig
261-
---@param methods? table<string, function> Optional method overrides for testing
267+
---@param args CodeCompanion.MCP.ClientArgs
262268
---@return CodeCompanion.MCP.Client
263-
function Client:new(name, cfg, methods)
264-
local static_methods = transform_static_methods(Client, methods)
265-
return setmetatable({
266-
name = name,
267-
cfg = cfg,
269+
function Client.new(args)
270+
local static_methods = transform_static_methods(Client, args.methods)
271+
local self = setmetatable({
272+
name = args.name,
273+
cfg = args.cfg,
268274
ready = false,
269-
transport = static_methods.new_transport(name, cfg, methods),
275+
transport = static_methods.new_transport({ name = args.name, cfg = args.cfg, methods = args.methods }),
270276
resp_handlers = {},
271-
server_request_handlers = {
272-
["ping"] = self._handle_server_ping,
273-
["roots/list"] = self._handler_server_roots_list,
274-
},
277+
server_request_handlers = {},
275278
methods = static_methods,
276-
}, self)
279+
}, Client)
280+
281+
self.server_request_handlers = {
282+
["ping"] = function()
283+
return self:_handle_server_ping()
284+
end,
285+
["roots/list"] = function()
286+
return self:_handle_server_roots_list()
287+
end,
288+
}
289+
290+
return self
277291
end
278292

279293
---Start the client.
@@ -531,7 +545,7 @@ end
531545

532546
---Handler for 'roots/list' server requests.
533547
---@return "result" | "error", table
534-
function Client:_handler_server_roots_list()
548+
function Client:_handle_server_roots_list()
535549
if not self.cfg.roots then
536550
return "error", { code = CONSTANTS.JSONRPC.ERROR_METHOD_NOT_FOUND, message = "roots capability not enabled" }
537551
end

lua/codecompanion/interactions/chat/mcp/init.lua

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function M.start_servers()
3030
local mcp_cfg = require("codecompanion.config").interactions.chat.mcp
3131
for name, cfg in pairs(mcp_cfg.servers or {}) do
3232
if not clients[name] then
33-
local client = Client:new(name, cfg)
33+
local client = Client.new({ name = name, cfg = cfg })
3434
clients[name] = client
3535
end
3636
end
@@ -58,6 +58,35 @@ function M.stop_servers()
5858
clients = {}
5959
end
6060

61+
---Restart all MCP servers
62+
---@return nil
63+
function M.restart_servers()
64+
M.stop_servers()
65+
M.start_servers()
66+
end
67+
68+
---Refresh configuration and restart servers
69+
---This allows users to update their MCP config and apply changes without restarting Neovim
70+
---@return nil
71+
function M.refresh()
72+
M.stop_servers()
73+
74+
-- Clear cached tool groups and tools from config
75+
local chat_tools = require("codecompanion.config").interactions.chat.tools
76+
for name, _ in pairs(chat_tools.groups) do
77+
if name:match("^mcp:") then
78+
chat_tools.groups[name] = nil
79+
end
80+
end
81+
for name, tool in pairs(chat_tools) do
82+
if type(tool) == "table" and vim.tbl_get(tool, "opts", "_mcp_info") then
83+
chat_tools[name] = nil
84+
end
85+
end
86+
87+
M.start_servers()
88+
end
89+
6190
---Get status of all MCP servers
6291
---@return table<string, { ready: boolean, tool_count: number, started: boolean }>
6392
function M.get_status()

tests/config.lua

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,9 @@ return {
357357
},
358358
},
359359
},
360-
mcp = {},
360+
mcp = {
361+
enabled = true,
362+
},
361363
opts = {
362364
blank_prompt = "",
363365
debounce = 0,

tests/interactions/chat/mcp/test_mcp_client.lua

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ local T = MiniTest.new_set({
4848
callback = function() tools_loaded = true end,
4949
})
5050
51-
CLI = Client:new("testMcp", { cmd = { "test-mcp" } }, { new_transport = mock_new_transport })
51+
CLI = Client.new({ name = "testMcp", cfg = { cmd = { "test-mcp" } }, methods = { new_transport = mock_new_transport } })
5252
CLI:start()
5353
vim.wait(1000, function() return tools_loaded end)
5454
end
@@ -84,7 +84,7 @@ T["MCP Client"]["start() starts and initializes the client once"] = function()
8484
TRANSPORT:expect_jsonrpc_notify("notifications/initialized", function() end)
8585
8686
setup_tool_list()
87-
CLI = Client:new("testMcp", { cmd = { "test-mcp" } }, { new_transport = mock_new_transport })
87+
CLI = Client.new({ name = "testMcp", cfg = { cmd = { "test-mcp" } }, methods = { new_transport = mock_new_transport } })
8888
CLI:start()
8989
CLI:start() -- repeated call should be no-op
9090
CLI:start()
@@ -290,10 +290,14 @@ T["MCP Client"]["roots capability is declared when roots config is provided"] =
290290
{ uri = "file:///home/user/project2", name = "Project 2" },
291291
}
292292
293-
CLI = Client:new("testMcp", {
294-
cmd = { "test-mcp" },
295-
roots = function() return roots end,
296-
}, { new_transport = mock_new_transport })
293+
CLI = Client.new({
294+
name = "testMcp",
295+
cfg = {
296+
cmd = { "test-mcp" },
297+
roots = function() return roots end,
298+
},
299+
methods = { new_transport = mock_new_transport },
300+
})
297301
CLI:start()
298302
vim.wait(1000, function() return CLI.ready end)
299303
@@ -331,13 +335,17 @@ T["MCP Client"]["roots list changed notification is sent when roots change"] = f
331335
local current_roots
332336
333337
local notify_roots_list_changed
334-
CLI = Client:new("testMcp", {
335-
cmd = { "test-mcp" },
336-
roots = function() return current_roots end,
337-
register_roots_list_changed = function(notify)
338-
notify_roots_list_changed = notify
339-
end,
340-
}, { new_transport = mock_new_transport })
338+
CLI = Client.new({
339+
name = "testMcp",
340+
cfg = {
341+
cmd = { "test-mcp" },
342+
roots = function() return current_roots end,
343+
register_roots_list_changed = function(notify)
344+
notify_roots_list_changed = notify
345+
end,
346+
},
347+
methods = { new_transport = mock_new_transport },
348+
})
341349
CLI:start()
342350
vim.wait(1000, function() return CLI.ready end)
343351
@@ -370,7 +378,7 @@ T["MCP Client"]["transport closed automatically on initialization failure"] = fu
370378
return "error", { code = -32603, message = "Initialization failed" }
371379
end)
372380
373-
CLI = Client:new("testMcp", { cmd = { "test-mcp" } }, { new_transport = mock_new_transport })
381+
CLI = Client.new({ name = "testMcp", cfg = { cmd = { "test-mcp" } }, methods = { new_transport = mock_new_transport } })
374382
CLI:start()
375383
vim.wait(1000, function() return TRANSPORT:all_handlers_consumed() end)
376384
vim.wait(1000, function() return not CLI.ready end)

tests/interactions/chat/mcp/test_mcp_tools.lua

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ local T = MiniTest.new_set({
3030
return not vim.startswith(tool.name, "math_")
3131
end):totable()
3232
33-
Client.static.methods.new_transport.default = function(name, cfg)
33+
Client.static.methods.new_transport.default = function(args)
3434
local transport
3535
local tools
36-
if cfg.cmd[1] == "math_mcp" then
36+
if args.cfg.cmd[1] == "math_mcp" then
3737
transport = MATH_MCP_TRANSPORT
3838
tools = MATH_MCP_TOOLS
3939
else

0 commit comments

Comments
 (0)