Skip to content

Commit 461b866

Browse files
committed
fix: range contains(), is_in_range() to use nvim equivalents
1 parent f7122ac commit 461b866

File tree

2 files changed

+209
-32
lines changed

2 files changed

+209
-32
lines changed
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
--- Copy of vim.treesitter._range
2+
--- TODO: replace with `vim.Range` when we drop support for 0.11
3+
local api = vim.api
4+
5+
local M = {}
6+
7+
---@class Range2
8+
---@inlinedoc
9+
---@field [1] integer start row
10+
---@field [2] integer end row
11+
12+
---@class Range4
13+
---@inlinedoc
14+
---@field [1] integer start row
15+
---@field [2] integer start column
16+
---@field [3] integer end row
17+
---@field [4] integer end column
18+
19+
---@class Range6
20+
---@inlinedoc
21+
---@field [1] integer start row
22+
---@field [2] integer start column
23+
---@field [3] integer start bytes
24+
---@field [4] integer end row
25+
---@field [5] integer end column
26+
---@field [6] integer end bytes
27+
28+
---@alias Range Range2|Range4|Range6
29+
30+
---@param a_row integer
31+
---@param a_col integer
32+
---@param b_row integer
33+
---@param b_col integer
34+
---@return integer
35+
--- 1: a > b
36+
--- 0: a == b
37+
--- -1: a < b
38+
local function cmp_pos(a_row, a_col, b_row, b_col)
39+
if a_row == b_row then
40+
if a_col > b_col then
41+
return 1
42+
elseif a_col < b_col then
43+
return -1
44+
else
45+
return 0
46+
end
47+
elseif a_row > b_row then
48+
return 1
49+
end
50+
51+
return -1
52+
end
53+
54+
M.cmp_pos = {
55+
lt = function(...)
56+
return cmp_pos(...) == -1
57+
end,
58+
le = function(...)
59+
return cmp_pos(...) ~= 1
60+
end,
61+
gt = function(...)
62+
return cmp_pos(...) == 1
63+
end,
64+
ge = function(...)
65+
return cmp_pos(...) ~= -1
66+
end,
67+
eq = function(...)
68+
return cmp_pos(...) == 0
69+
end,
70+
ne = function(...)
71+
return cmp_pos(...) ~= 0
72+
end,
73+
}
74+
75+
setmetatable(M.cmp_pos, { __call = cmp_pos })
76+
77+
---Check if a variable is a valid range object
78+
---@param r any
79+
---@return boolean
80+
function M.validate(r)
81+
if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then
82+
return false
83+
end
84+
85+
for _, e in
86+
ipairs(r --[[@as any[] ]])
87+
do
88+
if type(e) ~= 'number' then
89+
return false
90+
end
91+
end
92+
93+
return true
94+
end
95+
96+
---@param r1 Range
97+
---@param r2 Range
98+
---@return boolean
99+
function M.intercepts(r1, r2)
100+
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
101+
local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
102+
103+
-- r1 is above r2
104+
if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
105+
return false
106+
end
107+
108+
-- r1 is below r2
109+
if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
110+
return false
111+
end
112+
113+
return true
114+
end
115+
116+
---@param r1 Range6
117+
---@param r2 Range6
118+
---@return Range6?
119+
function M.intersection(r1, r2)
120+
if not M.intercepts(r1, r2) then
121+
return nil
122+
end
123+
local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1
124+
local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1
125+
return { rs[1], rs[2], rs[3], re[4], re[5], re[6] }
126+
end
127+
128+
---@param r Range
129+
---@return integer, integer, integer, integer
130+
function M.unpack4(r)
131+
if #r == 2 then
132+
return r[1], 0, r[2], 0
133+
end
134+
local off_1 = #r == 6 and 1 or 0
135+
return r[1], r[2], r[3 + off_1], r[4 + off_1]
136+
end
137+
138+
---@param r Range6
139+
---@return integer, integer, integer, integer, integer, integer
140+
function M.unpack6(r)
141+
return r[1], r[2], r[3], r[4], r[5], r[6]
142+
end
143+
144+
---@param r1 Range
145+
---@param r2 Range
146+
---@return boolean whether r1 contains r2
147+
function M.contains(r1, r2)
148+
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
149+
local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
150+
151+
-- start doesn't fit
152+
if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
153+
return false
154+
end
155+
156+
-- end doesn't fit
157+
if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
158+
return false
159+
end
160+
161+
return true
162+
end
163+
164+
--- @param source integer|string
165+
--- @param index integer
166+
--- @return integer
167+
local function get_offset(source, index)
168+
if index == 0 then
169+
return 0
170+
end
171+
172+
if type(source) == 'number' then
173+
return api.nvim_buf_get_offset(source, index)
174+
end
175+
176+
local byte = 0
177+
local next_offset = source:gmatch('()\n')
178+
local line = 1
179+
while line <= index do
180+
byte = next_offset() --[[@as integer]]
181+
line = line + 1
182+
end
183+
184+
return byte
185+
end
186+
187+
---@param source integer|string
188+
---@param range Range
189+
---@return Range6
190+
function M.add_bytes(source, range)
191+
if type(range) == 'table' and #range == 6 then
192+
return range --[[@as Range6]]
193+
end
194+
195+
local start_row, start_col, end_row, end_col = M.unpack4(range)
196+
-- TODO(vigoux): proper byte computation here, and account for EOL ?
197+
local start_byte = get_offset(source, start_row) + start_col
198+
local end_byte = get_offset(source, end_row) + end_col
199+
200+
return { start_row, start_col, start_byte, end_row, end_col, end_byte }
201+
end
202+
203+
return M

lua/nvim-treesitter-textobjects/shared.lua

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
local ts = vim.treesitter
2-
local add_bytes = require('vim.treesitter._range').add_bytes
2+
local ts_range = require('nvim-treesitter-textobjects._range')
33

44
-- lookup table for parserless queries
55
local lang_to_parser = { ecma = 'javascript', jsx = 'javascript' }
@@ -75,7 +75,7 @@ local get_query_matches = memoize(function(bufnr, query_group, root, root_lang)
7575
if query_name ~= nil then
7676
local path = vim.split(query_name, '%.')
7777
if metadata[id] and metadata[id].range then
78-
insert_to_path(prepared_match, path, add_bytes(bufnr, metadata[id].range))
78+
insert_to_path(prepared_match, path, ts_range.add_bytes(bufnr, metadata[id].range))
7979
else
8080
local srow, scol, sbyte, erow, ecol, ebyte = nodes[1]:range(true)
8181
if #nodes > 1 then
@@ -213,37 +213,11 @@ end
213213

214214
-- TODO: replace with `vim.Range:has(vim.Pos)` when we drop support for nvim 0.11
215215
---@param range Range4
216-
---@param row integer
216+
---@param line integer
217217
---@param col integer
218218
---@return boolean
219-
local function is_in_range(range, row, col)
220-
local start_row, start_col, end_row, end_col = unpack(range) ---@type integer, integer, integer, integer
221-
end_col = end_col - 1
222-
223-
local is_in_rows = start_row <= row and end_row >= row
224-
local is_after_start_col_if_needed = true
225-
if start_row == row then
226-
is_after_start_col_if_needed = col >= start_col
227-
end
228-
local is_before_end_col_if_needed = true
229-
if end_row == row then
230-
is_before_end_col_if_needed = col <= end_col
231-
end
232-
return is_in_rows and is_after_start_col_if_needed and is_before_end_col_if_needed
233-
end
234-
235-
-- TODO: replace with `vim.Range:has(vim.Range)` when we drop support for 0.11
236-
---@param outer Range4
237-
---@param inner Range4
238-
---@return boolean
239-
local function contains(outer, inner)
240-
local start_row_o, start_col_o, end_row_o, end_col_o = unpack(outer) ---@type integer, integer, integer, integer
241-
local start_row_i, start_col_i, end_row_i, end_col_i = unpack(inner) ---@type integer, integer, integer, integer
242-
243-
return start_row_o <= start_row_i
244-
and start_col_o <= start_col_i
245-
and end_row_o >= end_row_i
246-
and end_col_o >= end_col_i
219+
local function is_in_range(range, line, col)
220+
return ts_range.contains(range, { line, col, line, col + 1 })
247221
end
248222

249223
---@param range Range6
@@ -385,7 +359,7 @@ function M.textobject_at_point(query_string, query_group, bufnr, pos, opts)
385359

386360
local ranges_within_outer = {}
387361
for _, range in ipairs(ranges) do
388-
if contains(M.torange4(range_outer), M.torange4(range)) then
362+
if ts_range.contains(M.torange4(range_outer), M.torange4(range)) then
389363
table.insert(ranges_within_outer, range)
390364
end
391365
end

0 commit comments

Comments
 (0)