diff --git a/lua/copilot-lsp/lcs.lua b/lua/copilot-lsp/lcs.lua new file mode 100644 index 0000000..b7b8a23 --- /dev/null +++ b/lua/copilot-lsp/lcs.lua @@ -0,0 +1,179 @@ +local M = {} + +---@class copilotlsp.lcs.Edit +---@field kind copilotlsp.lcs.EditKind +---@field text string + +---@enum copilotlsp.lcs.EditKind +M.edit_kind = { + addition = "addition", + removal = "removal", + unchanged = "unchanged", +} + +--- Computes the Long Common Subequence table. +--- Reference: [https://en.wikipedia.org/wiki/Longest_common_subsequence#Computing_the_length_of_the_LCS] +---@param source string +---@param target string +function M.generate_table(source, target) + local n = #source + 1 + local m = #target + 1 + + ---@type integer[][] + local lcs = {} + for i = 1, n do + lcs[i] = {} + for j = 1, m do + lcs[i][j] = 0 + end + end + + for i = 2, n do + for j = 2, m do + if source:byte(i - 1) == target:byte(j - 1) then + lcs[i][j] = 1 + lcs[i - 1][j - 1] + else + lcs[i][j] = math.max(lcs[i - 1][j], lcs[i][j - 1]) + end + end + end + return lcs +end + +---@generic T +---@param tbl T[] +---@return T[] +local function reverse_table(tbl) + local ret = {} + for i = #tbl, 1, -1 do + table.insert(ret, tbl[i]) + end + return ret +end + +--- Calculates a diff between two strings using LCS +---@param source string +---@param target string +---@return copilotlsp.lcs.Edit[] +function M.diff(source, target) + local src_idx, trt_idx = #source + 1, #target + 1 + local lcs + + ---@type copilotlsp.lcs.Edit[] + local edits = {} + local edit_idx = 1 + + local edit_kind = M.edit_kind + + while src_idx > 1 or trt_idx > 1 do + if src_idx == 1 then + trt_idx = trt_idx - 1 + edits[edit_idx] = { + kind = edit_kind.addition, + text = string.char(string.byte(target, trt_idx)), + } + elseif trt_idx == 1 then + src_idx = src_idx - 1 + edits[edit_idx] = { + kind = edit_kind.removal, + text = string.char(string.byte(source, src_idx)), + } + else + local src_char = string.byte(source, src_idx - 1) + local trt_char = string.byte(target, trt_idx - 1) + + if src_char == trt_char then + src_idx, trt_idx = src_idx - 1, trt_idx - 1 + edits[edit_idx] = { + kind = edit_kind.unchanged, + text = string.char(src_char), + } + else + lcs = lcs or M.generate_table(source, target) + if lcs[src_idx - 1][trt_idx] <= lcs[src_idx][trt_idx - 1] then + trt_idx = trt_idx - 1 + edits[edit_idx] = { + kind = edit_kind.addition, + text = string.char(trt_char), + } + else + src_idx = src_idx - 1 + edits[edit_idx] = { + kind = edit_kind.removal, + text = string.char(src_char), + } + end + end + end + edit_idx = edit_idx + 1 + end + + return reverse_table(edits) +end + +---@param edits copilotlsp.lcs.Edit[] +---@param line integer +---@param character integer +---@return lsp.TextEdit[] +function M.to_lsp_edits(edits, line, character) + local function advance_cursor(edit) + if edit.text == "\n" then + line = line + 1 + character = 0 + else + character = character + 1 + end + end + + ---@type lsp.TextEdit[] + local lsp_edits = {} + local i = 1 + while i < #edits do + -- Skip all unchanged edits and advance cursor + while i < #edits and edits[i].kind == M.edit_kind.unchanged do + advance_cursor(edits[i]) + i = i + 1 + end + + -- No more edits to compute + if i >= #edits then + break + end + + local new_text = "" + local start_line, start_character = line, character + + -- Collect consecutive additions and removals + while i < #edits and edits[i].kind ~= M.edit_kind.unchanged do + if edits[i].kind == M.edit_kind.addition then + new_text = new_text .. edits[i].text + elseif edits[i].kind == M.edit_kind.removal then + advance_cursor(edits[i]) + else + error("unexcepted edit kind " .. edits[i].kind) + end + i = i + 1 + end + + ---@type lsp.TextEdit + local lsp_edit = { + newText = new_text, + range = { + start = { + line = start_line, + character = start_character, + }, + ["end"] = { + line = line, + character = character, + }, + }, + } + + table.insert(lsp_edits, lsp_edit) + end + + return lsp_edits +end + +return M diff --git a/lua/copilot-lsp/mimimal_edits.lua b/lua/copilot-lsp/mimimal_edits.lua new file mode 100644 index 0000000..a6be47f --- /dev/null +++ b/lua/copilot-lsp/mimimal_edits.lua @@ -0,0 +1,88 @@ +local lcs = require("copilot-lsp.lcs") +local M = {} + +---@param lines string[] +---@param range lsp.Range +local function extract_lines_from_range(lines, range) + local start_row = range.start.line + 1 + local start_col = range.start.character + 1 + local end_row = range["end"].line + 1 + local end_col = range["end"].character + 1 + + local source_lines = {} + -- Loop through the zero-indexed range [source_start_row, source_end_row) + for i = start_row, end_row do + local line = lines[i] + + if i == start_row then + line = line:sub(start_col, -1) + elseif i == end_row then + line = line:sub(1, end_col - 1) + end + + -- strip CR characters when neovim fails to identify the correct file format + if vim.endswith(line, "\r") then + table.insert(source_lines, line:sub(1, -2)) + else + table.insert(source_lines, line) + end + end + + return source_lines +end + +---@param source_buf string[] +---@param target_edit lsp.TextEdit +---@return lsp.TextEdit[] +function M.compute_minimal_edits(source_buf, target_edit) + local source_lines = extract_lines_from_range(source_buf, target_edit.range) + local target_lines = vim.split(target_edit.newText, "\r?\n") + + local source_text = table.concat(source_lines, "\n") + local target_text = table.concat(target_lines, "\n") + + local indices = vim.diff(source_text, target_text, { + algorithm = "histogram", + result_type = "indices", + }) + assert(type(indices) == "table") + + ---@type lsp.TextEdit[] + local edits = {} + + for _, idx in ipairs(indices) do + local source_line_start, source_line_count, target_line_start, target_line_count = unpack(idx) + local source_line_end = source_line_start + source_line_count - 1 + local target_line_end = target_line_start + target_line_count - 1 + + local source = table.concat(source_lines, "\n", source_line_start, source_line_end) + local target = table.concat(target_lines, "\n", target_line_start, target_line_end) + + local text_edits = lcs.to_lsp_edits( + lcs.diff(source, target), + source_line_start + target_edit.range.start.line - 1, + target_edit.range.start.character + ) + + vim.list_extend(edits, text_edits) + end + + local contains_non_whitespace_edit = vim.iter(edits):any(function(edit) + return edit.newText:find("%S") ~= nil + end) + + -- Diff the whole text if we encounter a non whitespace character in the edit. + -- This might happen when the formatted document deletes many lines + -- and `vim.diff` split those deletions into multiple hunks. + if contains_non_whitespace_edit then + edits = lcs.to_lsp_edits( + lcs.diff(source_text, target_text), + target_edit.range.start.line, + target_edit.range.start.character + ) + end + + return edits +end + +return M diff --git a/lua/copilot-lsp/nes/init.lua b/lua/copilot-lsp/nes/init.lua index 430466e..e204c7c 100644 --- a/lua/copilot-lsp/nes/init.lua +++ b/lua/copilot-lsp/nes/init.lua @@ -1,6 +1,7 @@ local errs = require("copilot-lsp.errors") local nes_ui = require("copilot-lsp.nes.ui") local utils = require("copilot-lsp.util") +local me = require("copilot-lsp.mimimal_edits") local M = {} @@ -22,6 +23,25 @@ local function handle_nes_response(err, result, ctx) --- Convert to textEdit fields edit.newText = edit.text end + if #result.edits == 1 then + ---@type copilotlsp.InlineEdit[] + local min_edits = {} + + -- minimise the edits + local source_lines = vim.api.nvim_buf_get_lines(ctx.bufnr, 0, -1, false) + local min_te = me.compute_minimal_edits(source_lines, result.edits[1]) + for _, edit in ipairs(min_te) do + local new_edit = { + range = edit.range, + newText = edit.newText, + command = result.edits[1].command, + textDocument = result.edits[1].textDocument, + } + table.insert(min_edits, new_edit) + end + + result.edits = min_edits + end nes_ui._display_next_suggestion(ctx.bufnr, nes_ns, result.edits) end