Skip to content

Commit cbe64f8

Browse files
committed
DTask: Add waitany and waitall helpers
1 parent a765cbe commit cbe64f8

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

src/dtask.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,32 @@ function Base.fetch(t::DTask; raw=false)
8585
end
8686
return fetch(t.future; raw)
8787
end
88+
function waitany(tasks::Vector{DTask})
89+
if isempty(tasks)
90+
return
91+
end
92+
cond = Threads.Condition()
93+
for task in tasks
94+
Sch.errormonitor_tracked("waitany listener", Threads.@spawn begin
95+
wait(task)
96+
@lock cond notify(cond)
97+
end)
98+
end
99+
@lock cond wait(cond)
100+
return
101+
end
102+
function waitall(tasks::Vector{DTask})
103+
if isempty(tasks)
104+
return
105+
end
106+
@sync for task in tasks
107+
Threads.@spawn begin
108+
wait(task)
109+
@lock cond notify(cond)
110+
end
111+
end
112+
return
113+
end
88114
function Base.show(io::IO, t::DTask)
89115
status = if istaskstarted(t)
90116
isready(t) ? "finished" : "running"

src/utils/tasks.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,115 @@ function set_task_tid!(task::Task, tid::Integer)
1818
end
1919
@assert Threads.threadid(task) == tid "jl_set_task_tid failed!"
2020
end
21+
22+
if isdefined(Base, :waitany)
23+
import Base: waitany, waitall
24+
else
25+
# Vendored from Base
26+
# License is MIT
27+
waitany(tasks; throw=true) = _wait_multiple(tasks, throw)
28+
waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast)
29+
function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false)
30+
tasks = Task[]
31+
32+
for t in waiting_tasks
33+
t isa Task || error("Expected an iterator of `Task` object")
34+
push!(tasks, t)
35+
end
36+
37+
if (all && !failfast) || length(tasks) <= 1
38+
exception = false
39+
# Force everything to finish synchronously for the case of waitall
40+
# with failfast=false
41+
for t in tasks
42+
_wait(t)
43+
exception |= istaskfailed(t)
44+
end
45+
if exception && throwexc
46+
exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)]
47+
throw(CompositeException(exceptions))
48+
else
49+
return tasks, Task[]
50+
end
51+
end
52+
53+
exception = false
54+
nremaining::Int = length(tasks)
55+
done_mask = falses(nremaining)
56+
for (i, t) in enumerate(tasks)
57+
if istaskdone(t)
58+
done_mask[i] = true
59+
exception |= istaskfailed(t)
60+
nremaining -= 1
61+
else
62+
done_mask[i] = false
63+
end
64+
end
65+
66+
if nremaining == 0
67+
return tasks, Task[]
68+
elseif any(done_mask) && (!all || (failfast && exception))
69+
if throwexc && (!all || failfast) && exception
70+
exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)]
71+
throw(CompositeException(exceptions))
72+
else
73+
return tasks[done_mask], tasks[.~done_mask]
74+
end
75+
end
76+
77+
chan = Channel{Int}(Inf)
78+
sentinel = current_task()
79+
waiter_tasks = fill(sentinel, length(tasks))
80+
81+
for (i, done) in enumerate(done_mask)
82+
done && continue
83+
t = tasks[i]
84+
if istaskdone(t)
85+
done_mask[i] = true
86+
exception |= istaskfailed(t)
87+
nremaining -= 1
88+
exception && failfast && break
89+
else
90+
waiter = @task put!(chan, i)
91+
waiter.sticky = false
92+
_wait2(t, waiter)
93+
waiter_tasks[i] = waiter
94+
end
95+
end
96+
97+
while nremaining > 0
98+
i = take!(chan)
99+
t = tasks[i]
100+
waiter_tasks[i] = sentinel
101+
done_mask[i] = true
102+
exception |= istaskfailed(t)
103+
nremaining -= 1
104+
105+
# stop early if requested, unless there is something immediately
106+
# ready to consume from the channel (using a race-y check)
107+
if (!all || (failfast && exception)) && !isready(chan)
108+
break
109+
end
110+
end
111+
112+
close(chan)
113+
114+
if nremaining == 0
115+
return tasks, Task[]
116+
else
117+
remaining_mask = .~done_mask
118+
for i in findall(remaining_mask)
119+
waiter = waiter_tasks[i]
120+
donenotify = tasks[i].donenotify::ThreadSynchronizer
121+
@lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter)
122+
end
123+
done_tasks = tasks[done_mask]
124+
if throwexc && exception
125+
exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)]
126+
throw(CompositeException(exceptions))
127+
else
128+
return done_tasks, tasks[remaining_mask]
129+
end
130+
end
131+
end
132+
end

0 commit comments

Comments
 (0)