Skip to content

Commit 5014b41

Browse files
committed
fix(context): ensure minimum number of search results
Adds MIN_RESULTS constant to guarantee at least 3 results are returned when searching for related code context, even if they don't meet the similarity threshold. This improves the quality of contextual responses by preventing cases where too few results would be returned. Also adds score normalization for symbol-based search and debug logging for generated messages.
1 parent 57d7736 commit 5014b41

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

lua/CopilotChat/client.lua

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ function Client:ask(prompt, opts)
427427
end
428428
end
429429

430+
log.debug('Generated messages: ', #generated_messages)
431+
430432
local last_message = nil
431433
local errored = false
432434
local finished = false

lua/CopilotChat/context.lua

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ local OFF_SIDE_RULE_LANGUAGES = {
6969
local MIN_SYMBOL_SIMILARITY = 0.3
7070
local MIN_SEMANTIC_SIMILARITY = 0.4
7171
local MULTI_FILE_THRESHOLD = 5
72+
local MIN_RESULTS = 3
7273
local MAX_FILES = 2500
7374

7475
--- Compute the cosine similarity between two vectors
@@ -103,16 +104,22 @@ local function data_ranked_by_relatedness(query, data, min_similarity)
103104
local results = {}
104105
for _, item in ipairs(data) do
105106
local similarity = spatial_distance_cosine(item.embedding, query.embedding, item.score)
106-
if similarity >= min_similarity then
107-
table.insert(results, vim.tbl_extend('force', item, { score = similarity }))
108-
end
107+
table.insert(results, vim.tbl_extend('force', item, { score = similarity }))
109108
end
110109

111110
table.sort(results, function(a, b)
112111
return a.score > b.score
113112
end)
114113

115-
return results
114+
-- Keep every item that meets the similarity threshold, plus the top MIN_RESULTS items regardless of score
115+
local filtered = {}
116+
for i, result in ipairs(results) do
117+
if (result.score >= min_similarity) or (i <= MIN_RESULTS) then
118+
table.insert(filtered, result)
119+
end
120+
end
121+
122+
return filtered
116123
end
117124

118125
-- Create trigrams from text (e.g., "hello" -> {"hel", "ell", "llo"})
@@ -168,7 +175,7 @@ local function data_ranked_by_symbols(query, data, min_similarity)
168175
local max_score = 0
169176

170177
for _, entry in ipairs(data) do
171-
local score = 0
178+
local score = entry.score or 0
172179
local basename = vim.fn.fnamemodify(entry.filename, ':t'):gsub('%..*$', '')
173180

174181
-- Get trigrams for basename and compound version
@@ -201,19 +208,24 @@ local function data_ranked_by_symbols(query, data, min_similarity)
201208
end
202209
end
203210

204-
-- Normalize and filter results
205-
local filtered_results = {}
211+
-- Normalize scores
206212
for _, result in ipairs(results) do
207213
result.score = result.score / max_score
208-
if result.score >= min_similarity then
209-
table.insert(filtered_results, result)
210-
end
211214
end
212215

213-
table.sort(filtered_results, function(a, b)
216+
-- Sort by score first
217+
table.sort(results, function(a, b)
214218
return a.score > b.score
215219
end)
216220

221+
-- Filter results while preserving top scores
222+
local filtered_results = {}
223+
for i, result in ipairs(results) do
224+
if (result.score >= min_similarity) or (i <= MIN_RESULTS) then
225+
table.insert(filtered_results, result)
226+
end
227+
end
228+
217229
return filtered_results
218230
end
219231

0 commit comments

Comments
 (0)