Skip to content

Commit 2479971

Browse files
Inline math highlighting (Continued from PR #426) (#428)
Co-authored-by: Kristijan Husak <[email protected]>
1 parent 0f4eb33 commit 2479971

File tree

2 files changed

+194
-46
lines changed

2 files changed

+194
-46
lines changed

lua/orgmode/colors/markup_highlighter.lua

Lines changed: 174 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,64 @@ local config = require('orgmode.config')
22
local ts_utils = require('nvim-treesitter.ts_utils')
33
local query = nil
44

5-
local valid_pre_marker_chars = { ' ', '(', '-', "'", '"', '{' }
6-
local valid_post_marker_chars = { ' ', ')', '-', '}', '"', "'", ':', ';', '!', '\\', '[', ',', '.', '?' }
5+
local valid_pre_marker_chars = { ' ', '(', '-', "'", '"', '{', '*', '/', '_', '+' }
6+
local valid_post_marker_chars =
7+
{ ' ', ')', '-', '}', '"', "'", ':', ';', '!', '\\', '[', ',', '.', '?', '*', '/', '_', '+' }
78

89
local markers = {
910
['*'] = {
1011
hl_name = 'org_bold',
1112
hl_cmd = 'hi def org_bold term=bold cterm=bold gui=bold',
13+
nestable = true,
14+
type = 'text',
1215
},
1316
['/'] = {
1417
hl_name = 'org_italic',
1518
hl_cmd = 'hi def org_italic term=italic cterm=italic gui=italic',
19+
nestable = true,
20+
type = 'text',
1621
},
1722
['_'] = {
1823
hl_name = 'org_underline',
1924
hl_cmd = 'hi def org_underline term=underline cterm=underline gui=underline',
25+
nestable = true,
26+
type = 'text',
2027
},
2128
['+'] = {
2229
hl_name = 'org_strikethrough',
2330
hl_cmd = 'hi def org_strikethrough term=strikethrough cterm=strikethrough gui=strikethrough',
31+
nestable = true,
32+
type = 'text',
2433
},
2534
['~'] = {
2635
hl_name = 'org_code',
2736
hl_cmd = 'hi def link org_code String',
37+
nestable = false,
38+
type = 'text',
2839
},
2940
['='] = {
3041
hl_name = 'org_verbatim',
3142
hl_cmd = 'hi def link org_verbatim String',
43+
nestable = false,
44+
type = 'text',
45+
},
46+
['\\('] = {
47+
hl_name = 'org_latex',
48+
hl_cmd = 'hi def link org_latex OrgTSLatex',
49+
nestable = false,
50+
type = 'latex',
51+
},
52+
['\\{'] = {
53+
hl_name = 'org_latex',
54+
hl_cmd = 'hi def link org_latex OrgTSLatex',
55+
nestable = false,
56+
type = 'latex',
57+
},
58+
['\\s'] = {
59+
hl_name = 'org_latex',
60+
hl_cmd = 'hi def link org_latex OrgTSLatex',
61+
nestable = false,
62+
type = 'latex',
3263
},
3364
}
3465

@@ -71,23 +102,18 @@ local get_tree = ts_utils.memoize_by_buf_tick(function(bufnr)
71102
return tree[1]:root()
72103
end)
73104

74-
local function get_predicate_nodes(match)
105+
local function get_predicate_nodes(match, n)
106+
local total = n or 2
75107
local counter = 1
76-
local start_node = nil
77-
local end_node = nil
108+
local nodes = {}
78109
for i, node in pairs(match) do
79-
if counter == 1 then
80-
start_node = node
81-
end
82-
if counter == 2 then
83-
end_node = node
84-
end
110+
nodes[counter] = node
85111
counter = counter + 1
112+
if counter > total then
113+
break
114+
end
86115
end
87-
if not start_node or not end_node then
88-
return false
89-
end
90-
return start_node, end_node
116+
return unpack(nodes)
91117
end
92118

93119
local function is_valid_markup_range(match, _, source, _)
@@ -96,9 +122,11 @@ local function is_valid_markup_range(match, _, source, _)
96122
return
97123
end
98124

99-
-- Ignore conflicts with hyperlink
100-
if start_node:type() == '[' or end_node:type() == ']' then
101-
return true
125+
-- Ignore conflicts with hyperlink or math
126+
for _, char in ipairs({ '[', '\\' }) do
127+
if start_node:type() == char or end_node:type() == char then
128+
return true
129+
end
102130
end
103131

104132
local start_line = start_node:range()
@@ -146,6 +174,48 @@ local function is_valid_hyperlink_range(match, _, source, _)
146174
return is_valid_start and is_valid_end
147175
end
148176

177+
local function is_valid_latex_range(match, _, source, _)
178+
local start_node_left, start_node_right, end_node = get_predicate_nodes(match, 3)
179+
-- Ignore conflicts with markup
180+
if start_node_left:type() ~= '\\' then
181+
return true
182+
end
183+
if not start_node_right or not end_node then
184+
return
185+
end
186+
187+
local start_line = start_node_left:range()
188+
local end_line = start_node_left:range()
189+
190+
if start_line ~= end_line then
191+
return false
192+
end
193+
194+
local _, start_left_col_end = start_node_left:end_()
195+
local _, start_right_col_end = start_node_right:end_()
196+
local start_text = get_node_text(start_node_left, source, 0, start_right_col_end - start_left_col_end)
197+
198+
if start_text == '\\(' then
199+
local end_text = get_node_text(end_node, source, -1, 0)
200+
if end_text == '\\)' then
201+
return true
202+
end
203+
else
204+
-- we have to deal with two cases here either \foo{bar} or \bar
205+
local char_after_start = get_node_text(start_node_right, source, 0, 1):sub(-1)
206+
local end_text = get_node_text(end_node, source, 0, 0)
207+
-- if \foo{bar}
208+
if char_after_start == '{' and end_text == '}' then
209+
return true
210+
end
211+
-- elseif \bar
212+
if not start_text:sub(2):match('%A') and end_text ~= '}' then
213+
return true
214+
end
215+
end
216+
return false
217+
end
218+
149219
local function load_deps()
150220
-- Already defined
151221
if query then
@@ -154,6 +224,7 @@ local function load_deps()
154224
query = vim.treesitter.get_query('org', 'markup')
155225
vim.treesitter.query.add_predicate('org-is-valid-markup-range?', is_valid_markup_range)
156226
vim.treesitter.query.add_predicate('org-is-valid-hyperlink-range?', is_valid_hyperlink_range)
227+
vim.treesitter.query.add_predicate('org-is-valid-latex-range?', is_valid_latex_range)
157228
end
158229

159230
---@param bufnr? number
@@ -173,15 +244,19 @@ local function get_matches(bufnr, first_line, last_line)
173244
for _, match, _ in query:iter_matches(root, bufnr, first_line, last_line) do
174245
for _, node in pairs(match) do
175246
local char = node:type()
176-
local range = ts_utils.node_to_lsp_range(node)
177-
local linenr = tostring(range.start.line)
178-
taken_locations[linenr] = taken_locations[linenr] or {}
179-
if not taken_locations[linenr][range.start.character] then
180-
table.insert(ranges, {
181-
type = char,
182-
range = range,
183-
})
184-
taken_locations[linenr][range.start.character] = true
247+
-- saves unnecessary parsing, since \\ is not used below
248+
if char ~= '\\' then
249+
local range = ts_utils.node_to_lsp_range(node)
250+
local linenr = tostring(range.start.line)
251+
taken_locations[linenr] = taken_locations[linenr] or {}
252+
if not taken_locations[linenr][range.start.character] then
253+
table.insert(ranges, {
254+
type = char,
255+
range = range,
256+
node = node,
257+
})
258+
taken_locations[linenr][range.start.character] = true
259+
end
185260
end
186261
end
187262
end
@@ -197,30 +272,74 @@ local function get_matches(bufnr, first_line, last_line)
197272
local seek_link = {}
198273
local result = {}
199274
local link_result = {}
275+
local latex_result = {}
276+
277+
local nested = {}
278+
local can_nest = true
279+
280+
local type_map = {
281+
['('] = '\\(',
282+
[')'] = '\\(',
283+
['}'] = '\\{',
284+
}
200285

201286
for _, item in ipairs(ranges) do
287+
if item.type == '(' then
288+
item.range.start.character = item.range.start.character - 1
289+
elseif item.type == 'str' then
290+
item.range.start.character = item.range.start.character - 1
291+
local char = get_node_text(item.node, bufnr, 0, 1):sub(-1)
292+
if char == '{' then
293+
item.type = '\\{'
294+
else
295+
item.type = '\\s'
296+
end
297+
end
298+
299+
item.type = type_map[item.type] or item.type
300+
202301
if markers[item.type] then
203302
if seek[item.type] then
204303
local from = seek[item.type]
205-
table.insert(result, {
206-
type = item.type,
207-
from = from.range,
208-
to = item.range,
209-
})
210-
211-
seek[item.type] = nil
212-
213-
for t, pos in pairs(seek) do
214-
if
215-
pos.range.start.line == from.range.start.line
216-
and pos.range.start.character > from.range['end'].character
217-
and pos.range.start.character < item.range.start.character
218-
then
219-
seek[t] = nil
304+
if nested[#nested] == nil or nested[#nested] == from.type then
305+
local target_result = result
306+
if markers[item.type].type == 'latex' then
307+
target_result = latex_result
308+
end
309+
310+
table.insert(target_result, {
311+
type = item.type,
312+
from = from.range,
313+
to = item.range,
314+
})
315+
316+
seek[item.type] = nil
317+
nested[#nested] = nil
318+
can_nest = true
319+
320+
for t, pos in pairs(seek) do
321+
if
322+
pos.range.start.line == from.range.start.line
323+
and pos.range.start.character > from.range['end'].character
324+
and pos.range.start.character < item.range.start.character
325+
then
326+
seek[t] = nil
327+
end
220328
end
221329
end
222-
else
223-
seek[item.type] = item
330+
elseif can_nest then
331+
-- escaped strings have no pairs, their markup info is self-contained
332+
if item.type == '\\s' then
333+
table.insert(result, {
334+
type = item.type,
335+
from = item.range,
336+
to = item.range,
337+
})
338+
else
339+
seek[item.type] = item
340+
nested[#nested + 1] = item.type
341+
can_nest = markers[item.type].nestable
342+
end
224343
end
225344
end
226345

@@ -237,7 +356,7 @@ local function get_matches(bufnr, first_line, last_line)
237356
end
238357
end
239358

240-
return result, link_result
359+
return result, link_result, latex_result
241360
end
242361

243362
local function apply(namespace, bufnr, _, first_line, last_line, _)
@@ -252,7 +371,7 @@ local function apply(namespace, bufnr, _, first_line, last_line, _)
252371
if #visible_lines == 0 then
253372
return
254373
end
255-
local ranges, link_ranges = get_matches(bufnr, visible_lines[1], visible_lines[#visible_lines])
374+
local ranges, link_ranges, latex_ranges = get_matches(bufnr, visible_lines[1], visible_lines[#visible_lines])
256375
local hide_markers = config.org_hide_emphasis_markers
257376

258377
for _, range in ipairs(ranges) do
@@ -301,6 +420,15 @@ local function apply(namespace, bufnr, _, first_line, last_line, _)
301420
conceal = '',
302421
})
303422
end
423+
424+
for _, latex_range in ipairs(latex_ranges) do
425+
vim.api.nvim_buf_set_extmark(bufnr, namespace, latex_range.from.start.line, latex_range.from.start.character, {
426+
ephemeral = true,
427+
end_col = latex_range.to['end'].character,
428+
hl_group = markers[latex_range.type].hl_name,
429+
priority = 110 + latex_range.from.start.character,
430+
})
431+
end
304432
end
305433

306434
local function setup()

queries/org/markup.scm

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
(expr "+" @strikethrough.start "+" @strikethrough.end (#org-is-valid-markup-range? @strikethrough.start @strikethrough.end))
1414
((expr "[" @hyperlink.start) (expr "]" @hyperlink.end) (#org-is-valid-hyperlink-range? @hyperlink.start @hyperlink.end))
1515
(expr "[" @hyperlink.start "]" @hyperlink.end (#org-is-valid-hyperlink-range? @hyperlink.start @hyperlink.end))
16+
((expr ("\\" @text.math.start.left "(" @text.math.start.right)) (expr ("\\" ")" @text.math.end)) (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
17+
(expr "\\" @text.math.start.left "(" @text.math.start.right "\\" ")" @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
18+
((expr ("\\" @text.math.start.left ("str")+ @text.math.start.right "{")) (expr "}" @text.math.end) (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
19+
(expr "\\" @text.math.start.left ("str")+ @text.math.start.right "{" "}" @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
20+
(expr "\\" @text.math.start.left ("str")+ @text.math.start.right @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
1621
])
1722

1823
(item [
@@ -30,6 +35,11 @@
3035
(expr "+" @strikethrough.start "+" @strikethrough.end (#org-is-valid-markup-range? @strikethrough.start @strikethrough.end))
3136
((expr "[" @hyperlink.start) (expr "]" @hyperlink.end) (#org-is-valid-hyperlink-range? @hyperlink.start @hyperlink.end))
3237
(expr "[" @hyperlink.start "]" @hyperlink.end (#org-is-valid-hyperlink-range? @hyperlink.start @hyperlink.end))
38+
((expr ("\\" @text.math.start.left "(" @text.math.start.right)) (expr ("\\" ")" @text.math.end)) (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
39+
(expr "\\" @text.math.start.left "(" @text.math.start.right "\\" ")" @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
40+
((expr ("\\" @text.math.start.left ("str")+ @text.math.start.right "{")) (expr "}" @text.math.end) (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
41+
(expr "\\" @text.math.start.left ("str")+ @text.math.start.right "{" "}" @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
42+
(expr "\\" @text.math.start.left ("str")+ @text.math.start.right @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
3343
])
3444

3545
(cell (contents [
@@ -47,6 +57,11 @@
4757
(expr "+" @strikethrough.start "+" @strikethrough.end (#org-is-valid-markup-range? @strikethrough.start @strikethrough.end))
4858
((expr "[" @hyperlink.start) (expr "]" @hyperlink.end) (#org-is-valid-hyperlink-range? @hyperlink.start @hyperlink.end))
4959
(expr "[" @hyperlink.start "]" @hyperlink.end (#org-is-valid-hyperlink-range? @hyperlink.start @hyperlink.end))
60+
((expr ("\\" @text.math.start.left "(" @text.math.start.right)) (expr ("\\" ")" @text.math.end)) (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
61+
(expr "\\" @text.math.start.left "(" @text.math.start.right "\\" ")" @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
62+
((expr ("\\" @text.math.start.left ("str")+ @text.math.start.right "{")) (expr "}" @text.math.end) (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
63+
(expr "\\" @text.math.start.left ("str")+ @text.math.start.right "{" "}" @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
64+
(expr "\\" @text.math.start.left ("str")+ @text.math.start.right @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
5065
]))
5166

5267
(drawer (contents [
@@ -64,4 +79,9 @@
6479
(expr "+" @strikethrough.start "+" @strikethrough.end (#org-is-valid-markup-range? @strikethrough.start @strikethrough.end))
6580
((expr "[" @hyperlink.start) (expr "]" @hyperlink.end) (#org-is-valid-hyperlink-range? @hyperlink.start @hyperlink.end))
6681
(expr "[" @hyperlink.start "]" @hyperlink.end (#org-is-valid-hyperlink-range? @hyperlink.start @hyperlink.end))
82+
((expr ("\\" @text.math.start.left "(" @text.math.start.right)) (expr ("\\" ")" @text.math.end)) (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
83+
(expr "\\" @text.math.start.left "(" @text.math.start.right "\\" ")" @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
84+
((expr ("\\" @text.math.start.left ("str")+ @text.math.start.right "{")) (expr "}" @text.math.end) (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
85+
(expr "\\" @text.math.start.left ("str")+ @text.math.start.right "{" "}" @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
86+
(expr "\\" @text.math.start.left ("str")+ @text.math.start.right @text.math.end (#org-is-valid-latex-range? @text.math.start.left @text.math.start.right @text.math.end))
6787
]))

0 commit comments

Comments
 (0)