Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions lua/copilot-lsp/lcs.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
local M = {}

---@class copilotlsp.lcs.Edit
---@field kind copilotlsp.lcs.EditKind
---@field text string

---@enum copilotlsp.lcs.EditKind
M.edit_kind = {
addition = "addition",
removal = "removal",
unchanged = "unchanged",
}

--- Computes the Long Common Subequence table.
--- Reference: [https://en.wikipedia.org/wiki/Longest_common_subsequence#Computing_the_length_of_the_LCS]
---@param source string
---@param target string
function M.generate_table(source, target)
local n = #source + 1
local m = #target + 1

---@type integer[][]
local lcs = {}
for i = 1, n do
lcs[i] = {}
for j = 1, m do
lcs[i][j] = 0
end
end

for i = 2, n do
for j = 2, m do
if source:byte(i - 1) == target:byte(j - 1) then
lcs[i][j] = 1 + lcs[i - 1][j - 1]
else
lcs[i][j] = math.max(lcs[i - 1][j], lcs[i][j - 1])
end
end
end
return lcs
end

---@generic T
---@param tbl T[]
---@return T[]
local function reverse_table(tbl)
local ret = {}
for i = #tbl, 1, -1 do
table.insert(ret, tbl[i])
end
return ret
end

--- Calculates a diff between two strings using LCS
---@param source string
---@param target string
---@return copilotlsp.lcs.Edit[]
function M.diff(source, target)
local src_idx, trt_idx = #source + 1, #target + 1
local lcs

---@type copilotlsp.lcs.Edit[]
local edits = {}
local edit_idx = 1

local edit_kind = M.edit_kind

while src_idx > 1 or trt_idx > 1 do
if src_idx == 1 then
trt_idx = trt_idx - 1
edits[edit_idx] = {
kind = edit_kind.addition,
text = string.char(string.byte(target, trt_idx)),
}
elseif trt_idx == 1 then
src_idx = src_idx - 1
edits[edit_idx] = {
kind = edit_kind.removal,
text = string.char(string.byte(source, src_idx)),
}
else
local src_char = string.byte(source, src_idx - 1)
local trt_char = string.byte(target, trt_idx - 1)

if src_char == trt_char then
src_idx, trt_idx = src_idx - 1, trt_idx - 1
edits[edit_idx] = {
kind = edit_kind.unchanged,
text = string.char(src_char),
}
else
lcs = lcs or M.generate_table(source, target)
if lcs[src_idx - 1][trt_idx] <= lcs[src_idx][trt_idx - 1] then
trt_idx = trt_idx - 1
edits[edit_idx] = {
kind = edit_kind.addition,
text = string.char(trt_char),
}
else
src_idx = src_idx - 1
edits[edit_idx] = {
kind = edit_kind.removal,
text = string.char(src_char),
}
end
end
end
edit_idx = edit_idx + 1
end

return reverse_table(edits)
end

---@param edits copilotlsp.lcs.Edit[]
---@param line integer
---@param character integer
---@return lsp.TextEdit[]
function M.to_lsp_edits(edits, line, character)
local function advance_cursor(edit)
if edit.text == "\n" then
line = line + 1
character = 0
else
character = character + 1
end
end

---@type lsp.TextEdit[]
local lsp_edits = {}
local i = 1
while i < #edits do
-- Skip all unchanged edits and advance cursor
while i < #edits and edits[i].kind == M.edit_kind.unchanged do
advance_cursor(edits[i])
i = i + 1
end

-- No more edits to compute
if i >= #edits then
break
end

local new_text = ""
local start_line, start_character = line, character

-- Collect consecutive additions and removals
while i < #edits and edits[i].kind ~= M.edit_kind.unchanged do
if edits[i].kind == M.edit_kind.addition then
new_text = new_text .. edits[i].text
elseif edits[i].kind == M.edit_kind.removal then
advance_cursor(edits[i])
else
error("unexcepted edit kind " .. edits[i].kind)
end
i = i + 1
end

---@type lsp.TextEdit
local lsp_edit = {
newText = new_text,
range = {
start = {
line = start_line,
character = start_character,
},
["end"] = {
line = line,
character = character,
},
},
}

table.insert(lsp_edits, lsp_edit)
end

return lsp_edits
end

return M
88 changes: 88 additions & 0 deletions lua/copilot-lsp/mimimal_edits.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
local lcs = require("copilot-lsp.lcs")
local M = {}

---@param lines string[]
---@param range lsp.Range
local function extract_lines_from_range(lines, range)
local start_row = range.start.line + 1
local start_col = range.start.character + 1
local end_row = range["end"].line + 1
local end_col = range["end"].character + 1

local source_lines = {}
-- Loop through the zero-indexed range [source_start_row, source_end_row)
for i = start_row, end_row do
local line = lines[i]

if i == start_row then
line = line:sub(start_col, -1)
elseif i == end_row then
line = line:sub(1, end_col - 1)
end

-- strip CR characters when neovim fails to identify the correct file format
if vim.endswith(line, "\r") then
table.insert(source_lines, line:sub(1, -2))
else
table.insert(source_lines, line)
end
end

return source_lines
end

---@param source_buf string[]
---@param target_edit lsp.TextEdit
---@return lsp.TextEdit[]
function M.compute_minimal_edits(source_buf, target_edit)
local source_lines = extract_lines_from_range(source_buf, target_edit.range)
local target_lines = vim.split(target_edit.newText, "\r?\n")

local source_text = table.concat(source_lines, "\n")
local target_text = table.concat(target_lines, "\n")

local indices = vim.diff(source_text, target_text, {
algorithm = "histogram",
result_type = "indices",
})
assert(type(indices) == "table")

---@type lsp.TextEdit[]
local edits = {}

for _, idx in ipairs(indices) do
local source_line_start, source_line_count, target_line_start, target_line_count = unpack(idx)
local source_line_end = source_line_start + source_line_count - 1
local target_line_end = target_line_start + target_line_count - 1

local source = table.concat(source_lines, "\n", source_line_start, source_line_end)
local target = table.concat(target_lines, "\n", target_line_start, target_line_end)

local text_edits = lcs.to_lsp_edits(
lcs.diff(source, target),
source_line_start + target_edit.range.start.line - 1,
target_edit.range.start.character
)

vim.list_extend(edits, text_edits)
end

local contains_non_whitespace_edit = vim.iter(edits):any(function(edit)
return edit.newText:find("%S") ~= nil
end)

-- Diff the whole text if we encounter a non whitespace character in the edit.
-- This might happen when the formatted document deletes many lines
-- and `vim.diff` split those deletions into multiple hunks.
if contains_non_whitespace_edit then
edits = lcs.to_lsp_edits(
lcs.diff(source_text, target_text),
target_edit.range.start.line,
target_edit.range.start.character
)
end

return edits
end

return M
20 changes: 20 additions & 0 deletions lua/copilot-lsp/nes/init.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local errs = require("copilot-lsp.errors")
local nes_ui = require("copilot-lsp.nes.ui")
local utils = require("copilot-lsp.util")
local me = require("copilot-lsp.mimimal_edits")

local M = {}

Expand All @@ -22,6 +23,25 @@ local function handle_nes_response(err, result, ctx)
--- Convert to textEdit fields
edit.newText = edit.text
end
if #result.edits == 1 then
---@type copilotlsp.InlineEdit[]
local min_edits = {}

-- minimise the edits
local source_lines = vim.api.nvim_buf_get_lines(ctx.bufnr, 0, -1, false)
local min_te = me.compute_minimal_edits(source_lines, result.edits[1])
for _, edit in ipairs(min_te) do
local new_edit = {
range = edit.range,
newText = edit.newText,
command = result.edits[1].command,
textDocument = result.edits[1].textDocument,
}
table.insert(min_edits, new_edit)
end

result.edits = min_edits
end
nes_ui._display_next_suggestion(ctx.bufnr, nes_ns, result.edits)
end

Expand Down
Loading