Skip to content

Commit 3dde77a

Browse files
committed
Add cancel! for cancelling tasks
1 parent f3845ab commit 3dde77a

File tree

5 files changed

+132
-4
lines changed

5 files changed

+132
-4
lines changed

src/Dagger.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ include("thunk.jl")
5252
include("submission.jl")
5353
include("chunks.jl")
5454
include("memory-spaces.jl")
55+
include("cancellation.jl")
5556

5657
# Task scheduling
5758
include("compute.jl")

src/cancellation.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
    cancel!(tid::Union{Int,Nothing}=nothing;
            sch_uid::Union{UInt64,Nothing}=nothing,
            force::Bool=false, halt_sch::Bool=false)

Request cancellation of tasks managed by the eager scheduler.

The request is executed on worker 1, where the eager scheduler state is read,
and the actual work is done by `_cancel!` while holding the scheduler lock.

# Arguments
- `tid`: task ID forwarded to `_cancel!` to select tasks; `nothing` (the
  default) cancels all tasks.

# Keywords
- `sch_uid`: scheduler uid forwarded to `_cancel!`; when `nothing`, the uid of
  the eager scheduler state is used.
- `force`: when `true`, running tasks are interrupted by throwing an
  `InterruptException` into them; otherwise their processor is told to drop
  them.
- `halt_sch`: when `true`, the scheduler is also halted, and this call waits
  until it has stopped.

Returns `nothing`.
"""
function cancel!(tid::Union{Int,Nothing}=nothing;
                 sch_uid::Union{UInt64,Nothing}=nothing,
                 force::Bool=false, halt_sch::Bool=false)
    remotecall_fetch(1, tid, sch_uid, force, halt_sch) do tid, sch_uid, force, halt_sch
        state = Sch.EAGER_STATE[]

        # NOTE(review): assumes `EAGER_STATE[]` is `nothing` while no eager
        # scheduler has been started — confirm against its definition. In that
        # case there is nothing to cancel, so return instead of failing on
        # `state.lock`.
        state === nothing && return nothing

        @lock state.lock _cancel!(state, tid, sch_uid, force, halt_sch)
        return nothing
    end
    return nothing
end
9+
"""
    _cancel!(state, tid, sch_uid, force, halt_sch)

Cancel ready, waiting, and running tasks tracked by the scheduler `state`.
The caller must already hold `state.lock`.

# Arguments
- `state`: scheduler state whose `ready`, `waiting`, `running_on`, `cache` and
  `errored` collections are consulted and updated.
- `tid`: ID of the single task to cancel, or `nothing` to cancel every task.
- `sch_uid`: scheduler uid used to look up processor states on each worker;
  `nothing` means `state.uid`.
- `force`: when `true`, running tasks are interrupted by throwing an
  `InterruptException` into them; otherwise the processor is told to drop the
  task and an `InterruptException` result is posted to its return queue.
- `halt_sch`: when `true`, the scheduler is halted after cancellation. The
  lock is released while waiting for the halt and re-acquired before
  returning.

Returns `nothing`.
"""
function _cancel!(state, tid, sch_uid, force, halt_sch)
    @assert islocked(state.lock)

    # Get the scheduler uid
    if sch_uid === nothing
        sch_uid = state.uid
    end

    # Cancel ready tasks.
    # Only tasks matching `tid` are selected (all of them when `tid` is
    # `nothing`). A snapshot is taken first because `Sch.set_failed!` is
    # defined elsewhere and may mutate `state.ready` while we iterate.
    ready_targets = [task for task in state.ready
                     if tid === nothing || task.id == tid]
    for task in ready_targets
        @dagdebug tid :cancel "Cancelling ready task"
        state.cache[task] = InterruptException()
        state.errored[task] = true
        Sch.set_failed!(state, task)
    end
    if tid === nothing
        empty!(state.ready)
    else
        # Leave unrelated tasks queued; only drop the cancelled one
        filter!(task -> task.id != tid, state.ready)
    end

    # Cancel waiting tasks (same selection and snapshot rules as above)
    waiting_targets = [task for task in keys(state.waiting)
                       if tid === nothing || task.id == tid]
    for task in waiting_targets
        @dagdebug tid :cancel "Cancelling waiting task"
        state.cache[task] = InterruptException()
        state.errored[task] = true
        Sch.set_failed!(state, task)
    end
    if tid === nothing
        empty!(state.waiting)
    else
        for task in waiting_targets
            delete!(state.waiting, task)
        end
    end

    # Cancel running tasks at the processor level
    wids = unique(map(root_worker_id, values(state.running_on)))
    for wid in wids
        # All inputs are passed as arguments so that the closure does not
        # capture (and serialize) any local state
        remotecall_fetch(wid, tid, sch_uid, force) do _tid, _sch_uid, _force
            Dagger.Sch.proc_states(_sch_uid) do states
                for (proc, pstate) in states
                    istate = pstate.state
                    any_cancelled = false
                    @lock istate.queue begin
                        for (running_tid, task) in istate.tasks
                            # Skip tasks that were not selected for cancellation
                            _tid !== nothing && running_tid != _tid && continue
                            task_spec = istate.task_specs[running_tid]
                            Tf = task_spec[6]
                            # Never cancel the eager scheduler's own thunk
                            Tf === typeof(Sch.eager_thunk) && continue
                            istaskdone(task) && continue
                            any_cancelled = true
                            @dagdebug running_tid :cancel "Cancelling running task ($Tf)"
                            if _force
                                @dagdebug running_tid :cancel "Interrupting running task ($Tf)"
                                Threads.@spawn Base.throwto(task, InterruptException())
                            else
                                # Tell the processor to just drop this task
                                task_occupancy = task_spec[4]
                                time_util = task_spec[2]
                                istate.proc_occupancy[] -= task_occupancy
                                istate.time_pressure[] -= time_util
                                push!(istate.cancelled, running_tid)
                                to_proc = istate.proc
                                put!(istate.return_queue,
                                     (myid(), to_proc, running_tid,
                                      (InterruptException(), nothing)))
                            end
                        end
                    end
                    if any_cancelled
                        notify(istate.reschedule)
                    end
                end
            end
            return
        end
    end

    if halt_sch
        unlock(state.lock)
        try
            # Give tasks a moment to be processed
            sleep(0.5)

            # Halt the scheduler
            @dagdebug nothing :cancel "Halting the scheduler"
            notify(state.halt)
            put!(state.chan, (1, nothing, nothing, (Sch.SchedulerHaltedException(), nothing)))

            # Wait for the scheduler to halt
            @dagdebug nothing :cancel "Waiting for scheduler to halt"
            while Sch.EAGER_INIT[]
                sleep(0.1)
            end
            @dagdebug nothing :cancel "Scheduler halted"
        finally
            # Always hand the lock back to the caller, even if halting threw,
            # so that the caller's own unlock stays balanced
            lock(state.lock)
        end
    end

    return
end

src/sch/Sch.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,8 @@ function scheduler_exit(ctx, state::ComputeState, options)
639639
lock(ctx.proc_notify) do
640640
notify(ctx.proc_notify)
641641
end
642+
643+
@dagdebug nothing :global "Tore down scheduler" uid=state.uid
642644
end
643645

644646
function procs_to_use(ctx, options=ctx.options)
@@ -1147,11 +1149,14 @@ Base.hash(key::TaskSpecKey, h::UInt) = hash(key.task_id, hash(TaskSpecKey, h))
11471149
# Internal per-processor state used by the processor runner and by task
# cancellation.
#
# Fields, as used by the code visible in this change:
# - `return_queue`: channel that task results are `put!` onto, as
#   `(worker id, processor, thunk id, result)` tuples.
# - `queue`: lock-protected priority queue of task specs; its lock is also held
#   while `tasks`, `task_specs`, `cancelled` and the counters are updated.
# - `reschedule`: doorbell that is notified when a task finishes or is
#   cancelled.
# - `tasks`: thunk id => the Julia `Task` executing that thunk.
# - `task_specs`: thunk id => the task spec vector the thunk was launched from.
# - `proc_occupancy` / `time_pressure`: counters increased when a task is
#   launched and decreased when it finishes or is cancelled.
# - `cancelled`: thunk ids dropped by cancellation; an id is removed again when
#   the corresponding task finishes.
# NOTE(review): the roles of `ctx`, `proc` and `done` are not shown in this
# hunk — confirm against the rest of the file.
struct ProcessorInternalState
11481150
ctx::Context
11491151
proc::Processor
1152+
return_queue::RemoteChannel
11501153
queue::LockedObject{PriorityQueue{TaskSpecKey, UInt32, Base.Order.ForwardOrdering}}
11511154
reschedule::Doorbell
11521155
tasks::Dict{Int,Task}
1156+
task_specs::Dict{Int,Vector{Any}}
11531157
proc_occupancy::Base.RefValue{UInt32}
11541158
time_pressure::Base.RefValue{UInt64}
1159+
cancelled::Set{Int}
11551160
done::Base.RefValue{Bool}
11561161
end
11571162
struct ProcessorState
@@ -1300,6 +1305,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13001305

13011306
# Execute the task and return its result
13021307
t = @task begin
1308+
was_cancelled = false
13031309
result = try
13041310
do_task(to_proc, task)
13051311
catch err
@@ -1308,11 +1314,23 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13081314
finally
13091315
lock(istate.queue) do _
13101316
delete!(tasks, thunk_id)
1311-
proc_occupancy[] -= task_occupancy
1312-
time_pressure[] -= time_util
1317+
delete!(istate.task_specs, thunk_id)
1318+
if !(thunk_id in istate.cancelled)
1319+
proc_occupancy[] -= task_occupancy
1320+
time_pressure[] -= time_util
1321+
else
1322+
# Task was cancelled, so occupancy and pressure are
1323+
# already reduced
1324+
pop!(istate.cancelled, thunk_id)
1325+
was_cancelled = true
1326+
end
13131327
end
13141328
notify(istate.reschedule)
13151329
end
1330+
if was_cancelled
1331+
# A result was already posted to the return queue
1332+
return
1333+
end
13161334
try
13171335
put!(return_queue, (myid(), to_proc, thunk_id, result))
13181336
catch err
@@ -1331,6 +1349,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13311349
t.sticky = false
13321350
end
13331351
tasks[thunk_id] = errormonitor_tracked("thunk $thunk_id", schedule(t))
1352+
istate.task_specs[thunk_id] = task
13341353
proc_occupancy[] += task_occupancy
13351354
time_pressure[] += time_util
13361355
end
@@ -1363,10 +1382,12 @@ function do_tasks(to_proc, return_queue, tasks)
13631382
queue = PriorityQueue{TaskSpecKey, UInt32}()
13641383
queue_locked = LockedObject(queue)
13651384
reschedule = Doorbell()
1366-
istate = ProcessorInternalState(ctx, to_proc,
1385+
istate = ProcessorInternalState(ctx, to_proc, return_queue,
13671386
queue_locked, reschedule,
13681387
Dict{Int,Task}(),
1388+
Dict{Int,Vector{Any}}(),
13691389
Ref(UInt32(0)), Ref(UInt64(0)),
1390+
Set{Int}(),
13701391
Ref(false))
13711392
runner = start_processor_runner!(istate, uid, return_queue)
13721393
@static if VERSION < v"1.9"

src/threadproc.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n
2626
fetch(task)
2727
return result[]
2828
catch err
29+
if err isa InterruptException
30+
if !istaskdone(task)
31+
# Propagate cancellation signal
32+
Threads.@spawn Base.throwto(task, InterruptException())
33+
end
34+
end
2935
err, frames = Base.current_exceptions(task)[1]
3036
rethrow(CapturedException(err, frames))
3137
end

src/utils/dagdebug.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ function task_id end
33

44
const DAGDEBUG_CATEGORIES = Symbol[:global, :submit, :schedule, :scope,
55
:take, :execute, :move, :processor,
6-
:stream]
6+
:stream, :cancel]
77
macro dagdebug(thunk, category, msg, args...)
88
cat_sym = category.value
99
@gensym id

0 commit comments

Comments
 (0)