Skip to content

Commit b3a8fa2

Browse files
authored
feat(cli, nvim): Use chunk_id to deduplicate chunks in codecompanion tool (#183)
* feat(cli): Return `chunk_id` in structured query result output * Auto generate docs * feat(nvim): Deduplicate tool results using chunk_id (wip) * feat(nvim): Deduplicate tool results using in-house result tracker * fix(nvim): make sure `no_duplicate` option is effective * refactor(nvim): refactoring and cleanup --------- Co-authored-by: Davidyz <[email protected]>
1 parent fc4391b commit b3a8fa2

File tree

7 files changed

+101
-8
lines changed

7 files changed

+101
-8
lines changed

doc/VectorCode-cli.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,17 +660,21 @@ If you used `--include chunk path` parameters, the array will look like this:
660660
"chunk": "foo",
661661
"start_line": 1,
662662
"end_line": 1,
663+
"chunk_id": "chunk_id_1"
663664
},
664665
{
665666
"path": "path_to_another_file.py",
666667
"chunk": "bar",
667668
"start_line": 1,
668669
"end_line": 1,
670+
"chunk_id": "chunk_id_2"
669671
}
670672
]
671673
<
672674

673-
Keep in mind that both `start_line` and `end_line` are inclusive.
675+
Keep in mind that both `start_line` and `end_line` are inclusive. The
676+
`chunk_id` is a random string that can be used as a unique identifier to
677+
distinguish between chunks. These are the same IDs used in the database.
674678

675679

676680
VECTORCODE VECTORISE

docs/cli.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -595,16 +595,20 @@ If you used `--include chunk path` parameters, the array will look like this:
595595
"chunk": "foo",
596596
"start_line": 1,
597597
"end_line": 1,
598+
"chunk_id": "chunk_id_1"
598599
},
599600
{
600601
"path": "path_to_another_file.py",
601602
"chunk": "bar",
602603
"start_line": 1,
603604
"end_line": 1,
605+
"chunk_id": "chunk_id_2"
604606
}
605607
]
606608
```
607-
Keep in mind that both `start_line` and `end_line` are inclusive.
609+
Keep in mind that both `start_line` and `end_line` are inclusive. The `chunk_id`
610+
is a random string that can be used as a unique identifier to distinguish
611+
between chunks. These are the same IDs used in the database.
608612
609613
#### `vectorcode vectorise`
610614
The output is in JSON format. It contains a dictionary with the following fields:

lua/vectorcode/integrations/codecompanion/common.lua

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
---@module "codecompanion"
2+
13
local job_runner
24
local vc_config = require("vectorcode.config")
35
local notify_opts = vc_config.notify_opts
@@ -17,8 +19,10 @@ local default_ls_options = {}
1719
---@type VectorCode.CodeCompanion.VectoriseToolOpts
1820
local default_vectorise_options = {}
1921

22+
local TOOL_RESULT_SOURCE = "VectorCodeToolResult"
23+
2024
return {
21-
tool_result_source = "VectorCodeToolResult",
25+
tool_result_source = TOOL_RESULT_SOURCE,
2226
---@param t table|string
2327
---@return string
2428
flatten_table_to_string = function(t)
@@ -122,6 +126,7 @@ return {
122126
end
123127
return llm_message
124128
end,
129+
125130
---@param use_lsp boolean
126131
---@return VectorCode.JobRunner
127132
initialise_runner = function(use_lsp)

lua/vectorcode/integrations/codecompanion/query_tool.lua

Lines changed: 79 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,77 @@ local job_runner = nil
99

1010
---@alias QueryToolArgs { project_root:string, count: integer, query: string[] }
1111

12+
---@alias chat_id integer
13+
---@alias result_id string
14+
---@type table<chat_id, table<result_id, boolean>>
15+
local result_tracker = {}
16+
17+
---@param results VectorCode.QueryResult[]
18+
---@param chat CodeCompanion.Chat
19+
---@return VectorCode.QueryResult[]
20+
local filter_results = function(results, chat)
21+
local existing_refs = chat.refs or {}
22+
23+
existing_refs = vim
24+
.iter(existing_refs)
25+
:filter(
26+
---@param ref CodeCompanion.Chat.Ref
27+
function(ref)
28+
return ref.source == cc_common.tool_result_source or ref.path or ref.bufnr
29+
end
30+
)
31+
:map(
32+
---@param ref CodeCompanion.Chat.Ref
33+
function(ref)
34+
if ref.source == cc_common.tool_result_source then
35+
return ref.id
36+
elseif ref.path then
37+
return ref.path
38+
elseif ref.bufnr then
39+
return vim.api.nvim_buf_get_name(ref.bufnr)
40+
end
41+
end
42+
)
43+
:totable()
44+
45+
---@type VectorCode.QueryResult[]
46+
local filtered_results = vim
47+
.iter(results)
48+
:filter(
49+
---@param res VectorCode.QueryResult
50+
function(res)
51+
-- return true if res should be kept
52+
if res.chunk then
53+
if res.chunk_id == nil then
54+
-- no chunk_id, always include
55+
return true
56+
end
57+
if
58+
result_tracker[chat.id] ~= nil and result_tracker[chat.id][res.chunk_id]
59+
then
60+
return false
61+
end
62+
return not vim.tbl_contains(existing_refs, res.chunk_id)
63+
else
64+
if result_tracker[chat.id] ~= nil and result_tracker[chat.id][res.path] then
65+
return false
66+
end
67+
return not vim.tbl_contains(existing_refs, res.path)
68+
end
69+
end
70+
)
71+
:totable()
72+
73+
for _, res in pairs(filtered_results) do
74+
if result_tracker[chat.id] == nil then
75+
result_tracker[chat.id] = {}
76+
end
77+
result_tracker[chat.id][res.chunk_id or res.path] = true
78+
end
79+
80+
return filtered_results
81+
end
82+
1283
---@param opts VectorCode.CodeCompanion.QueryToolOpts?
1384
---@return CodeCompanion.Agent.Tool
1485
return check_cli_wrap(function(opts)
@@ -148,6 +219,7 @@ You may include multiple keywords in the command.
148219
description = [[
149220
Query messages used for the search. They should also contain relevant keywords.
150221
For example, you should include `parameter`, `arguments` and `return value` for the query `function`.
222+
If a query returned empty or repeated results, you should avoid using these query keywords, unless the user instructed otherwise.
151223
]],
152224
},
153225
count = {
@@ -219,6 +291,9 @@ For example, you should include `parameter`, `arguments` and `return value` for
219291
if opts.max_num > 0 then
220292
max_result = math.min(opts.max_num or 1, max_result)
221293
end
294+
if opts.no_duplicate then
295+
stdout = filter_results(stdout, agent.chat)
296+
end
222297
for i, file in pairs(stdout) do
223298
if i <= max_result then
224299
if i == 1 then
@@ -240,14 +315,14 @@ For example, you should include `parameter`, `arguments` and `return value` for
240315
user_message
241316
)
242317
if not opts.chunk_mode then
243-
-- skip referencing because there will be multiple chunks with the same path (id).
244-
-- TODO: figure out a way to deduplicate.
245-
agent.chat.references:add({
318+
-- only add the file to the chat references when running in full-document mode
319+
local ref = {
246320
source = cc_common.tool_result_source,
247321
id = file.path,
248322
path = file.path,
249323
opts = { visible = false },
250-
})
324+
}
325+
agent.chat.references:add(ref)
251326
end
252327
end
253328
end

lua/vectorcode/types.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
---@field chunk string?
66
---@field start_line integer?
77
---@field end_line integer?
8+
---@field chunk_id string?
89

910
---@class VectorCode.LsResult
1011
---@field project-root string

src/vectorcode/subcommands/query/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ async def build_query_results(
114114
assert chunk_texts is not None, (
115115
"QueryResult does not contain `documents`!"
116116
)
117-
full_result: dict[str, str | int] = {"chunk": str(chunk_texts[0])}
117+
full_result: dict[str, str | int] = {
118+
"chunk": str(chunk_texts[0]),
119+
"chunk_id": identifier,
120+
}
118121
if meta[0].get("start") is not None and meta[0].get("end") is not None:
119122
path = str(meta[0].get("path"))
120123
with open(path) as fin:

tests/subcommands/query/test_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ async def test_build_query_results_chunk_mode_success(mock_collection, mock_conf
173173
"chunk": expected_chunk_content,
174174
"start_line": start_line,
175175
"end_line": end_line,
176+
"chunk_id": identifier,
176177
}
177178

178179
assert results[0] == expected_full_result

0 commit comments

Comments
 (0)