Skip to content

Commit ecd03f5

Browse files
authored
fix: range contains(), is_in_range() to use nvim equivalents (#845)
Changes most functions to accept generic `Range` type instead of only accepting a specific `Range4`, `Range6` type etc. and converting the type at the call site. Mainly changed `shared.torange4()` -> `_range.unpack4()` to achieve this.
1 parent f7122ac commit ecd03f5

File tree

7 files changed

+271
-71
lines changed

7 files changed

+271
-71
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
--- Copy of vim.treesitter._range
2+
--- TODO: replace with `vim.Range` when we drop support for 0.11
3+
4+
---@diagnostic disable: duplicate-doc-field
5+
---@diagnostic disable: duplicate-doc-alias
6+
7+
local api = vim.api
8+
9+
local M = {}
10+
11+
---@class Range2
12+
---@inlinedoc
13+
---@field [1] integer start row
14+
---@field [2] integer end row
15+
16+
---@class Range4
17+
---@inlinedoc
18+
---@field [1] integer start row
19+
---@field [2] integer start column
20+
---@field [3] integer end row
21+
---@field [4] integer end column
22+
23+
---@class Range6
24+
---@inlinedoc
25+
---@field [1] integer start row
26+
---@field [2] integer start column
27+
---@field [3] integer start bytes
28+
---@field [4] integer end row
29+
---@field [5] integer end column
30+
---@field [6] integer end bytes
31+
32+
---@alias Range Range2|Range4|Range6
33+
34+
---@param a_row integer
35+
---@param a_col integer
36+
---@param b_row integer
37+
---@param b_col integer
38+
---@return integer
39+
--- 1: a > b
40+
--- 0: a == b
41+
--- -1: a < b
42+
local function cmp_pos(a_row, a_col, b_row, b_col)
43+
if a_row == b_row then
44+
if a_col > b_col then
45+
return 1
46+
elseif a_col < b_col then
47+
return -1
48+
else
49+
return 0
50+
end
51+
elseif a_row > b_row then
52+
return 1
53+
end
54+
55+
return -1
56+
end
57+
58+
M.cmp_pos = {
59+
lt = function(...)
60+
return cmp_pos(...) == -1
61+
end,
62+
le = function(...)
63+
return cmp_pos(...) ~= 1
64+
end,
65+
gt = function(...)
66+
return cmp_pos(...) == 1
67+
end,
68+
ge = function(...)
69+
return cmp_pos(...) ~= -1
70+
end,
71+
eq = function(...)
72+
return cmp_pos(...) == 0
73+
end,
74+
ne = function(...)
75+
return cmp_pos(...) ~= 0
76+
end,
77+
}
78+
79+
setmetatable(M.cmp_pos, { __call = cmp_pos })
80+
81+
---Check if a variable is a valid range object
82+
---@param r any
83+
---@return boolean
84+
function M.validate(r)
85+
if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then
86+
return false
87+
end
88+
89+
for _, e in
90+
ipairs(r --[[@as any[] ]])
91+
do
92+
if type(e) ~= 'number' then
93+
return false
94+
end
95+
end
96+
97+
return true
98+
end
99+
100+
---@param r1 Range
101+
---@param r2 Range
102+
---@return boolean
103+
function M.intercepts(r1, r2)
104+
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
105+
local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
106+
107+
-- r1 is above r2
108+
if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
109+
return false
110+
end
111+
112+
-- r1 is below r2
113+
if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
114+
return false
115+
end
116+
117+
return true
118+
end
119+
120+
---@param r1 Range6
121+
---@param r2 Range6
122+
---@return Range6?
123+
function M.intersection(r1, r2)
124+
if not M.intercepts(r1, r2) then
125+
return nil
126+
end
127+
local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1
128+
local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1
129+
return { rs[1], rs[2], rs[3], re[4], re[5], re[6] }
130+
end
131+
132+
---@param r Range
133+
---@return integer, integer, integer, integer
134+
function M.unpack4(r)
135+
if #r == 2 then
136+
return r[1], 0, r[2], 0
137+
end
138+
local off_1 = #r == 6 and 1 or 0
139+
return r[1], r[2], r[3 + off_1], r[4 + off_1]
140+
end
141+
142+
---@param r Range6
143+
---@return integer, integer, integer, integer, integer, integer
144+
function M.unpack6(r)
145+
return r[1], r[2], r[3], r[4], r[5], r[6]
146+
end
147+
148+
---@param r1 Range
149+
---@param r2 Range
150+
---@return boolean whether r1 contains r2
151+
function M.contains(r1, r2)
152+
local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
153+
local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
154+
155+
-- start doesn't fit
156+
if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
157+
return false
158+
end
159+
160+
-- end doesn't fit
161+
if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
162+
return false
163+
end
164+
165+
return true
166+
end
167+
168+
--- @param source integer|string
169+
--- @param index integer
170+
--- @return integer
171+
local function get_offset(source, index)
172+
if index == 0 then
173+
return 0
174+
end
175+
176+
if type(source) == 'number' then
177+
return api.nvim_buf_get_offset(source, index)
178+
end
179+
180+
local byte = 0
181+
local next_offset = source:gmatch('()\n')
182+
local line = 1
183+
while line <= index do
184+
byte = next_offset() --[[@as integer]]
185+
line = line + 1
186+
end
187+
188+
return byte
189+
end
190+
191+
---@param source integer|string
192+
---@param range Range
193+
---@return Range6
194+
function M.add_bytes(source, range)
195+
if type(range) == 'table' and #range == 6 then
196+
return range --[[@as Range6]]
197+
end
198+
199+
local start_row, start_col, end_row, end_col = M.unpack4(range)
200+
-- TODO(vigoux): proper byte computation here, and account for EOL ?
201+
local start_byte = get_offset(source, start_row) + start_col
202+
local end_byte = get_offset(source, end_row) + end_col
203+
204+
return { start_row, start_col, start_byte, end_row, end_col, end_byte }
205+
end
206+
207+
return M

lua/nvim-treesitter-textobjects/move.lua

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ local api = vim.api
33
local shared = require('nvim-treesitter-textobjects.shared')
44
local repeatable_move = require('nvim-treesitter-textobjects.repeatable_move')
55
local global_config = require('nvim-treesitter-textobjects.config')
6+
local ts_range = vim.treesitter._range or require('nvim-treesitter-textobjects._range')
67

7-
---@param range Range4?
8+
---@param range Range?
89
---@param goto_end boolean
910
---@param avoid_set_jump boolean
1011
local function goto_node(range, goto_end, avoid_set_jump)
@@ -16,7 +17,7 @@ local function goto_node(range, goto_end, avoid_set_jump)
1617
vim.cmd("normal! m'")
1718
end
1819
---@type integer, integer, integer, integer
19-
local start_row, start_col, end_row, end_col = unpack(range)
20+
local start_row, start_col, end_row, end_col = ts_range.unpack4(range)
2021

2122
-- Enter visual mode if we are in operator pending mode
2223
-- If we don't do this, it will miss the last character.
@@ -89,13 +90,13 @@ local function move(opts, query_strings, query_group)
8990
end
9091

9192
---@param start_ boolean
92-
---@param range Range6
93+
---@param range Range
9394
---@return boolean
9495
local function filter_function(start_, range)
9596
local row, col = unpack(api.nvim_win_get_cursor(winid)) --[[@as integer, integer]]
9697
row = row - 1 -- nvim_win_get_cursor is (1,0)-indexed
9798
---@type integer, integer, integer, integer, integer, integer
98-
local start_row, start_col, _, end_row, end_col, _ = unpack(range)
99+
local start_row, start_col, end_row, end_col = ts_range.unpack4(range)
99100

100101
if not start_ then
101102
if end_col == 0 then
@@ -146,7 +147,7 @@ local function move(opts, query_strings, query_group)
146147
end
147148
end
148149
end
149-
goto_node(best_range and shared.torange4(best_range), not best_start, not config.set_jumps)
150+
goto_node(best_range and best_range, not best_start, not config.set_jumps)
150151
end
151152
end
152153

lua/nvim-treesitter-textobjects/select.lua

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
local api = vim.api
22
local global_config = require('nvim-treesitter-textobjects.config')
33
local shared = require('nvim-treesitter-textobjects.shared')
4+
local ts_range = vim.treesitter._range or require('nvim-treesitter-textobjects._range')
45

5-
---@param range Range4
6+
---@param range Range
67
---@param selection_mode TSTextObjects.SelectionMode
78
local function update_selection(range, selection_mode)
89
---@type integer, integer, integer, integer
9-
local start_row, start_col, end_row, end_col = unpack(range)
10+
local start_row, start_col, end_row, end_col = ts_range.unpack4(range)
1011
selection_mode = selection_mode or 'v'
1112

1213
-- enter visual mode if normal or operator-pending (no) mode
@@ -105,11 +106,11 @@ local function previous_position(bufnr, row, col)
105106
end
106107

107108
---@param bufnr integer
108-
---@param range Range4
109+
---@param range Range
109110
---@param selection_mode string
110111
---@return Range4?
111112
local function include_surrounding_whitespace(bufnr, range, selection_mode)
112-
local start_row, start_col, end_row, end_col = unpack(range) ---@type integer, integer, integer, integer
113+
local start_row, start_col, end_row, end_col = ts_range.unpack4(range) ---@type integer, integer, integer, integer
113114
local extended = false
114115
local position = { end_row, end_col - 1 }
115116
local next = next_position(bufnr, unpack(position))
@@ -166,19 +167,19 @@ function M.select_textobject(query_string, query_group)
166167
{ lookahead = lookahead, lookbehind = lookbehind }
167168
)
168169
if range6 then
169-
local range4 = shared.torange4(range6)
170170
local selection_mode = M.detect_selection_mode(query_string)
171171
if
172172
function_or_value_to_value(surrounding_whitespace, {
173173
query_string = query_string,
174174
selection_mode = selection_mode,
175175
})
176176
then
177-
---@diagnostic disable-next-line: cast-local-type
178-
range4 = include_surrounding_whitespace(bufnr, range4, selection_mode)
179-
end
180-
if range4 then
181-
update_selection(range4, selection_mode)
177+
local range4 = include_surrounding_whitespace(bufnr, range6, selection_mode)
178+
if range4 then
179+
update_selection(range4, selection_mode)
180+
end
181+
else
182+
update_selection(range6, selection_mode)
182183
end
183184
end
184185
end

lua/nvim-treesitter-textobjects/shared.lua

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
local ts = vim.treesitter
2-
local add_bytes = require('vim.treesitter._range').add_bytes
2+
local ts_range = ts._range or require('nvim-treesitter-textobjects._range')
33

44
-- lookup table for parserless queries
55
local lang_to_parser = { ecma = 'javascript', jsx = 'javascript' }
@@ -75,7 +75,7 @@ local get_query_matches = memoize(function(bufnr, query_group, root, root_lang)
7575
if query_name ~= nil then
7676
local path = vim.split(query_name, '%.')
7777
if metadata[id] and metadata[id].range then
78-
insert_to_path(prepared_match, path, add_bytes(bufnr, metadata[id].range))
78+
insert_to_path(prepared_match, path, ts_range.add_bytes(bufnr, metadata[id].range))
7979
else
8080
local srow, scol, sbyte, erow, ecol, ebyte = nodes[1]:range(true)
8181
if #nodes > 1 then
@@ -212,44 +212,12 @@ function M.find_best_range(bufnr, capture_string, query_group, filter_predicate,
212212
end
213213

214214
-- TODO: replace with `vim.Range:has(vim.Pos)` when we drop support for nvim 0.11
215-
---@param range Range4
216-
---@param row integer
215+
---@param range Range
216+
---@param line integer
217217
---@param col integer
218218
---@return boolean
219-
local function is_in_range(range, row, col)
220-
local start_row, start_col, end_row, end_col = unpack(range) ---@type integer, integer, integer, integer
221-
end_col = end_col - 1
222-
223-
local is_in_rows = start_row <= row and end_row >= row
224-
local is_after_start_col_if_needed = true
225-
if start_row == row then
226-
is_after_start_col_if_needed = col >= start_col
227-
end
228-
local is_before_end_col_if_needed = true
229-
if end_row == row then
230-
is_before_end_col_if_needed = col <= end_col
231-
end
232-
return is_in_rows and is_after_start_col_if_needed and is_before_end_col_if_needed
233-
end
234-
235-
-- TODO: replace with `vim.Range:has(vim.Range)` when we drop support for 0.11
236-
---@param outer Range4
237-
---@param inner Range4
238-
---@return boolean
239-
local function contains(outer, inner)
240-
local start_row_o, start_col_o, end_row_o, end_col_o = unpack(outer) ---@type integer, integer, integer, integer
241-
local start_row_i, start_col_i, end_row_i, end_col_i = unpack(inner) ---@type integer, integer, integer, integer
242-
243-
return start_row_o <= start_row_i
244-
and start_col_o <= start_col_i
245-
and end_row_o >= end_row_i
246-
and end_col_o >= end_col_i
247-
end
248-
249-
---@param range Range6
250-
---@return Range4
251-
function M.torange4(range)
252-
return { range[1], range[2], range[4], range[5] }
219+
local function is_in_range(range, line, col)
220+
return ts_range.contains(range, { line, col, line, col + 1 })
253221
end
254222

255223
--- Get the best `TSTextObjects.Range` at a given point
@@ -274,7 +242,7 @@ local function best_range_at_point(ranges, row, col, opts)
274242
local lookbehind_earliest_start ---@type integer
275243

276244
for _, range in pairs(ranges) do
277-
if range and is_in_range(M.torange4(range), row, col) then
245+
if range and is_in_range(range, row, col) then
278246
local length = range[6] - range[3]
279247
if not range_length or length < range_length then
280248
smallest_range = range
@@ -385,7 +353,7 @@ function M.textobject_at_point(query_string, query_group, bufnr, pos, opts)
385353

386354
local ranges_within_outer = {}
387355
for _, range in ipairs(ranges) do
388-
if contains(M.torange4(range_outer), M.torange4(range)) then
356+
if ts_range.contains(range_outer, range) then
389357
table.insert(ranges_within_outer, range)
390358
end
391359
end

0 commit comments

Comments
 (0)