Skip to content

Commit c07db4c

Browse files
committed
cancellation: Add cancel token support
1 parent ed2493c commit c07db4c

File tree

5 files changed

+75
-5
lines changed

5 files changed

+75
-5
lines changed

src/Dagger.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,16 @@ include("processor.jl")
5757
include("threadproc.jl")
5858
include("context.jl")
5959
include("utils/processors.jl")
60+
include("dtask.jl")
61+
include("cancellation.jl")
6062
include("task-tls.jl")
6163
include("scopes.jl")
6264
include("utils/scopes.jl")
63-
include("dtask.jl")
6465
include("queue.jl")
6566
include("thunk.jl")
6667
include("submission.jl")
6768
include("chunks.jl")
6869
include("memory-spaces.jl")
69-
include("cancellation.jl")
7070

7171
# Task scheduling
7272
include("compute.jl")

src/cancellation.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,38 @@
1+
# DTask-level cancellation
2+
3+
"""
    CancelToken

Cancellation signal for a single task: a boolean flag plus an event that
waiters can block on.

The flag is a `Threads.Atomic{Bool}` because it is set by the task requesting
cancellation and polled by the task being cancelled. A plain `Ref` would be a
data race, and a polling loop could legally never observe the store. The
`token.cancelled[]` read/write syntax is unchanged.
"""
struct CancelToken
    cancelled::Threads.Atomic{Bool}
    event::Base.Event
end

"""
    CancelToken()

Create a token that is not cancelled and whose event has not been notified.
"""
CancelToken() = CancelToken(Threads.Atomic{Bool}(false), Base.Event())

"""
    cancel!(token::CancelToken)

Mark `token` as cancelled and wake every task blocked in `wait(token)`.
Returns `nothing`. Calling it more than once is harmless.
"""
function cancel!(token::CancelToken)
    # Publish the flag before notifying, so a woken waiter always sees `true`.
    token.cancelled[] = true
    notify(token.event)
    return
end

"""
    is_cancelled(token::CancelToken) -> Bool

Return `true` once `cancel!(token)` has been called, else `false`.
"""
is_cancelled(token::CancelToken) = token.cancelled[]

"""
    wait(token::CancelToken)

Block the calling task until `token`'s event has been notified by `cancel!`.
"""
Base.wait(token::CancelToken) = wait(token.event)
15+
# TODO: Enable this for safety
16+
#Serialization.serialize(io::AbstractSerializer, ::CancelToken) =
17+
# throw(ConcurrencyViolationError("Cannot serialize a CancelToken"))
18+
19+
const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing)
20+
21+
"""
    clone_cancel_token_remote(orig_token::CancelToken, wid::Integer)

Create a fresh `CancelToken` on worker `wid` (stored there via `poolset`) and
spawn a local task that waits on `orig_token`, then calls `cancel!` on the
remote token through `MemPool.access_ref`.

The value returned is whatever `errormonitor_tracked` returns for the spawned
communicator task.
NOTE(review): the handle to the remote token is not returned. Confirm whether
callers are expected to receive it.
"""
function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer)
    # Allocate the mirror token on the target worker and keep a handle to it.
    remote_ref = remotecall_fetch(() -> poolset(CancelToken()), wid)

    # Forward the cancellation once the original token fires.
    communicator = Threads.@spawn begin
        wait(orig_token)
        @dagdebug nothing :cancel "Cancelling remote token on worker $wid"
        MemPool.access_ref(remote_ref) do tok
            cancel!(tok)
        end
    end
    return errormonitor_tracked("remote cancel_token communicator", communicator)
end
33+
34+
# Global-level cancellation
35+
136
"""
237
cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false)
338
@@ -80,11 +115,11 @@ function _cancel!(state, tid, force, halt_sch)
80115
Tf === typeof(Sch.eager_thunk) && continue
81116
istaskdone(task) && continue
82117
any_cancelled = true
83-
@dagdebug tid :cancel "Cancelling running task ($Tf)"
84118
if force
85119
@dagdebug tid :cancel "Interrupting running task ($Tf)"
86120
Threads.@spawn Base.throwto(task, InterruptException())
87121
else
122+
@dagdebug tid :cancel "Cancelling running task ($Tf)"
88123
# Tell the processor to just drop this task
89124
task_occupancy = task_spec[4]
90125
time_util = task_spec[2]
@@ -93,6 +128,7 @@ function _cancel!(state, tid, force, halt_sch)
93128
push!(istate.cancelled, tid)
94129
to_proc = istate.proc
95130
put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing)))
131+
cancel!(istate.cancel_tokens[tid])
96132
end
97133
end
98134
end

src/sch/Sch.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@ struct ProcessorInternalState
11831183
proc_occupancy::Base.RefValue{UInt32}
11841184
time_pressure::Base.RefValue{UInt64}
11851185
cancelled::Set{Int}
1186+
cancel_tokens::Dict{Int,Dagger.CancelToken}
11861187
done::Base.RefValue{Bool}
11871188
end
11881189
struct ProcessorState
@@ -1332,7 +1333,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13321333

13331334
# Execute the task and return its result
13341335
t = @task begin
1336+
# Set up cancellation
1337+
cancel_token = Dagger.CancelToken()
1338+
Dagger.DTASK_CANCEL_TOKEN[] = cancel_token
1339+
lock(istate.queue) do _
1340+
istate.cancel_tokens[thunk_id] = cancel_token
1341+
end
13351342
was_cancelled = false
1343+
13361344
result = try
13371345
do_task(to_proc, task)
13381346
catch err
@@ -1349,6 +1357,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13491357
# Task was cancelled, so occupancy and pressure are
13501358
# already reduced
13511359
pop!(istate.cancelled, thunk_id)
1360+
delete!(istate.cancel_tokens, thunk_id)
13521361
was_cancelled = true
13531362
end
13541363
end
@@ -1366,6 +1375,9 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13661375
else
13671376
rethrow(err)
13681377
end
1378+
finally
1379+
# Ensure that any spawned tasks get cleaned up
1380+
Dagger.cancel!(cancel_token)
13691381
end
13701382
end
13711383
lock(istate.queue) do _
@@ -1415,6 +1427,7 @@ function do_tasks(to_proc, return_queue, tasks)
14151427
Dict{Int,Vector{Any}}(),
14161428
Ref(UInt32(0)), Ref(UInt64(0)),
14171429
Set{Int}(),
1430+
Dict{Int,Dagger.CancelToken}(),
14181431
Ref(false))
14191432
runner = start_processor_runner!(istate, uid, return_queue)
14201433
@static if VERSION < v"1.9"
@@ -1656,6 +1669,7 @@ function do_task(to_proc, task_desc)
16561669
sch_handle,
16571670
processor=to_proc,
16581671
task_spec=task_desc,
1672+
cancel_token=Dagger.DTASK_CANCEL_TOKEN[],
16591673
))
16601674

16611675
res = Dagger.with_options(propagated) do

src/task-tls.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ struct DTaskTLS
55
sch_uid::UInt
66
sch_handle::Any # FIXME: SchedulerHandle
77
task_spec::Vector{Any} # FIXME: TaskSpec
8+
cancel_token::CancelToken
89
end
910

1011
const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing)
@@ -22,7 +23,7 @@ get_tls() = DTASK_TLS[]::DTaskTLS
2223
Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`.
2324
"""
2425
function set_tls!(tls)
25-
DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec)
26+
DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token)
2627
end
2728

2829
"""
@@ -40,3 +41,21 @@ Get the current processor executing the current [`DTask`](@ref).
4041
"""
4142
task_processor() = get_tls().processor
4243
@deprecate thunk_processor() task_processor()
44+
45+
"""
    task_cancelled() -> Bool

Report whether the [`DTask`](@ref) that is currently executing has been
cancelled: `true` if its cancel token's flag is set, `false` otherwise.
"""
function task_cancelled()
    token = get_tls().cancel_token
    return token.cancelled[]
end
51+
52+
"""
    task_may_cancel!()

Cancellation checkpoint for the current [`DTask`](@ref): throw an
`InterruptException` if the task has been cancelled, otherwise return
`nothing`.
"""
function task_may_cancel!()
    task_cancelled() && throw(InterruptException())
    return nothing
end

src/threadproc.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n
2727
return result[]
2828
catch err
2929
if err isa InterruptException
30+
# Direct interrupt hit us, propagate cancellation signal
31+
# FIXME: We should tell the scheduler that the user hit Ctrl-C
3032
if !istaskdone(task)
31-
# Propagate cancellation signal
3233
Threads.@spawn Base.throwto(task, InterruptException())
3334
end
3435
end

0 commit comments

Comments
 (0)