Skip to content

Commit 92ff217

Browse files
authored
feat: migrate to _async (#228)
* feat: migrate to _async
1 parent 8c1f379 commit 92ff217

File tree

7 files changed

+349
-220
lines changed

7 files changed

+349
-220
lines changed

Makefile

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
LUAROCKS_PATH_CMD = luarocks path --no-bin --lua-version 5.1
2+
BUSTED = eval $$(luarocks path --no-bin --lua-version 5.1) && busted --lua nlua
3+
TEST_DIR = spec
4+
5+
.PHONY: test
6+
test:
7+
@echo "Running tests..."
8+
@if [ -n "$(file)" ]; then \
9+
$(BUSTED) $(file); \
10+
else \
11+
$(BUSTED) $(TEST_DIR); \
12+
fi

lua/guard/_async.lua

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
local M = {}
2+
3+
local max_timeout = 30000
4+
local copcall = package.loaded.jit and pcall or require('coxpcall').pcall
5+
6+
--- @param thread thread
7+
--- @param on_finish fun(err: string?, ...:any)
8+
--- @param ... any
9+
local function resume(thread, on_finish, ...)
10+
--- @type {n: integer, [1]:boolean, [2]:string|function}
11+
local ret = vim.F.pack_len(coroutine.resume(thread, ...))
12+
local stat = ret[1]
13+
14+
if not stat then
15+
-- Coroutine had error
16+
on_finish(ret[2] --[[@as string]])
17+
elseif coroutine.status(thread) == 'dead' then
18+
-- Coroutine finished
19+
on_finish(nil, unpack(ret, 2, ret.n))
20+
else
21+
local fn = ret[2]
22+
--- @cast fn -string
23+
24+
--- @type boolean, string?
25+
local ok, err = copcall(fn, function(...)
26+
resume(thread, on_finish, ...)
27+
end)
28+
29+
if not ok then
30+
on_finish(err)
31+
end
32+
end
33+
end
34+
35+
--- @param func async fun(): ...:any
36+
--- @param on_finish? fun(err: string?, ...:any)
37+
function M.run(func, on_finish)
38+
local res --- @type {n:integer, [integer]:any}?
39+
resume(coroutine.create(func), function(err, ...)
40+
res = vim.F.pack_len(err, ...)
41+
if on_finish then
42+
on_finish(err, ...)
43+
end
44+
end)
45+
46+
return {
47+
--- @param timeout? integer
48+
--- @return any ... return values of `func`
49+
wait = function(_self, timeout)
50+
vim.wait(timeout or max_timeout, function()
51+
return res ~= nil
52+
end)
53+
assert(res, 'timeout')
54+
if res[1] then
55+
error(res[1])
56+
end
57+
return unpack(res, 2, res.n)
58+
end,
59+
}
60+
end
61+
62+
--- Asynchronous blocking wait
63+
--- @async
64+
--- @param argc integer
65+
--- @param fun function
66+
--- @param ... any func arguments
67+
--- @return any ...
68+
function M.await(argc, fun, ...)
69+
assert(coroutine.running(), 'Async.await() must be called from an async function')
70+
local args = vim.F.pack_len(...) --- @type {n:integer, [integer]:any}
71+
72+
--- @param callback fun(...:any)
73+
return coroutine.yield(function(callback)
74+
args[argc] = assert(callback)
75+
fun(unpack(args, 1, math.max(argc, args.n)))
76+
end)
77+
end
78+
79+
--- @async
80+
--- @param max_jobs integer
81+
--- @param funs (async fun())[]
82+
function M.join(max_jobs, funs)
83+
if #funs == 0 then
84+
return
85+
end
86+
87+
max_jobs = math.min(max_jobs, #funs)
88+
89+
--- @type (async fun())[]
90+
local remaining = { select(max_jobs + 1, unpack(funs)) }
91+
local to_go = #funs
92+
93+
M.await(1, function(on_finish)
94+
local function run_next()
95+
to_go = to_go - 1
96+
if to_go == 0 then
97+
on_finish()
98+
elseif #remaining > 0 then
99+
local next_fun = table.remove(remaining)
100+
M.run(next_fun, run_next)
101+
end
102+
end
103+
104+
for i = 1, max_jobs do
105+
M.run(funs[i], run_next)
106+
end
107+
end)
108+
end
109+
110+
return M

0 commit comments

Comments
 (0)