Skip to content

Commit c7e54d2

Browse files
authored
perf(chat): #552 improve settings tree-sitter query (#708)
1 parent 96570e6 commit c7e54d2

File tree

1 file changed

+24
-13
lines changed
  • lua/codecompanion/strategies/chat

1 file changed

+24
-13
lines changed

lua/codecompanion/strategies/chat/init.lua

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ local util = require("codecompanion.utils")
1111
local yaml = require("codecompanion.utils.yaml")
1212

1313
local api = vim.api
14+
local get_node_text = vim.treesitter.get_node_text --[[@type function]]
15+
local get_query = vim.treesitter.query.get --[[@type function]]
1416

1517
local CONSTANTS = {
1618
AUTOCMD_GROUP = "codecompanion.chat",
@@ -32,7 +34,8 @@ local function make_id(val)
3234
return hash.hash(val)
3335
end
3436

35-
local _cached_settings
37+
local _cached_settings = {}
38+
local _yaml_parser
3639

3740
---Parse the chat buffer for settings
3841
---@param bufnr integer
@@ -47,21 +50,29 @@ local function ts_parse_settings(bufnr, adapter)
4750
if not config.display.chat.show_settings then
4851
if adapter then
4952
_cached_settings[bufnr] = adapter:get_default_settings()
50-
5153
return _cached_settings[bufnr]
5254
end
5355
end
5456

5557
local settings = {}
56-
local parser = vim.treesitter.get_parser(bufnr, "yaml", { ignore_injections = false })
57-
local query = vim.treesitter.query.get("yaml", "chat")
58-
local root = parser:parse()[1]:root()
58+
if not _yaml_parser then
59+
_yaml_parser = vim.treesitter.get_parser(bufnr, "yaml", { ignore_injections = false })
60+
end
61+
62+
local query = get_query("yaml", "chat")
63+
local root = _yaml_parser:parse()[1]:root()
64+
65+
local end_line = -1
66+
if adapter then
67+
-- Account for the two YAML lines and the fact Tree-sitter is 0-indexed
68+
end_line = vim.tbl_count(adapter:get_default_settings()) + 2 - 1
69+
end
5970

60-
for _, matches, _ in query:iter_matches(root, bufnr) do
71+
for _, matches, _ in query:iter_matches(root, bufnr, 0, end_line) do
6172
local nodes = matches[1]
6273
local node = type(nodes) == "table" and nodes[1] or nodes
6374

64-
local value = vim.treesitter.get_node_text(node, bufnr)
75+
local value = get_node_text(node, bufnr)
6576

6677
settings = yaml.decode(value)
6778
break
@@ -81,7 +92,7 @@ end
8192
---@param start_range number
8293
---@return { content: string }
8394
local function ts_parse_messages(chat, role, start_range)
84-
local query = vim.treesitter.query.get("markdown", "chat")
95+
local query = get_query("markdown", "chat")
8596

8697
local tree = chat.parser:parse({ start_range - 1, -1 })[1]
8798
local root = tree:root()
@@ -91,9 +102,9 @@ local function ts_parse_messages(chat, role, start_range)
91102

92103
for id, node in query:iter_captures(root, chat.bufnr, start_range - 1, -1) do
93104
if query.captures[id] == "role" then
94-
last_role = vim.treesitter.get_node_text(node, chat.bufnr)
105+
last_role = get_node_text(node, chat.bufnr)
95106
elseif last_role == chat.ui:format_header(user_role) and query.captures[id] == "content" then
96-
table.insert(content, vim.treesitter.get_node_text(node, chat.bufnr))
107+
table.insert(content, get_node_text(node, chat.bufnr))
97108
end
98109
end
99110

@@ -111,7 +122,7 @@ end
111122
---@return TSNode | nil
112123
local function ts_parse_codeblock(chat, cursor)
113124
local root = chat.parser:parse()[1]:root()
114-
local query = vim.treesitter.query.get("markdown", "chat")
125+
local query = get_query("markdown", "chat")
115126
if query == nil then
116127
return nil
117128
end
@@ -315,7 +326,7 @@ function Chat:set_autocmds()
315326
if errors and node then
316327
for child in node:iter_children() do
317328
assert(child:type() == "block_mapping_pair")
318-
local key = vim.treesitter.get_node_text(child:named_child(0), self.bufnr)
329+
local key = get_node_text(child:named_child(0), self.bufnr)
319330
if errors[key] then
320331
local lnum, col, end_lnum, end_col = child:range()
321332
table.insert(items, {
@@ -350,7 +361,7 @@ function Chat:_get_settings_key(opts)
350361
return
351362
end
352363
local key_node = node:named_child(0)
353-
local key_name = vim.treesitter.get_node_text(key_node, self.bufnr)
364+
local key_name = get_node_text(key_node, self.bufnr)
354365
return key_name, node
355366
end
356367

0 commit comments

Comments
 (0)