Skip to content

Commit 488ae7a

Browse files
committed
cancellation: Add cancel token support
1 parent 0763d99 commit 488ae7a

File tree

5 files changed

+75
-5
lines changed

5 files changed

+75
-5
lines changed

src/Dagger.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,16 @@ include("processor.jl")
4848
include("threadproc.jl")
4949
include("context.jl")
5050
include("utils/processors.jl")
51+
include("dtask.jl")
52+
include("cancellation.jl")
5153
include("task-tls.jl")
5254
include("scopes.jl")
5355
include("utils/scopes.jl")
54-
include("dtask.jl")
5556
include("queue.jl")
5657
include("thunk.jl")
5758
include("submission.jl")
5859
include("chunks.jl")
5960
include("memory-spaces.jl")
60-
include("cancellation.jl")
6161

6262
# Task scheduling
6363
include("compute.jl")

src/cancellation.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,38 @@
1+
# DTask-level cancellation
2+
3+
"""
    CancelToken()

Cancellation flag paired with an event, used for `DTask`-level cancellation.

A token starts out not cancelled. `cancel!(token)` sets the flag and notifies the
event, `is_cancelled(token)` polls the flag without blocking, and `wait(token)`
blocks until the token has been cancelled.

The flag is held in a `Threads.Atomic{Bool}` because it is written by whichever
task calls `cancel!` and polled from other tasks, which may run on other threads.
It is still read and written as `token.cancelled[]`.
"""
struct CancelToken
    cancelled::Threads.Atomic{Bool}
    event::Base.Event
end
CancelToken() = CancelToken(Threads.Atomic{Bool}(false), Base.Event())
# Keep the earlier positional constructor working. The current value of the `Ref`
# is copied into a fresh atomic; the `Ref` itself is not aliased by the token.
CancelToken(cancelled::Base.RefValue{Bool}, event::Base.Event) =
    CancelToken(Threads.Atomic{Bool}(cancelled[]), event)

"""
    cancel!(token::CancelToken)

Mark `token` as cancelled and wake all tasks blocked in `wait(token)`.
Returns `nothing`.
"""
function cancel!(token::CancelToken)
    # Set the flag before notifying, so that woken waiters observe it as set
    token.cancelled[] = true
    notify(token.event)
    return
end

"""
    is_cancelled(token::CancelToken) -> Bool

Return `true` if `cancel!` has been called on `token`. Never blocks.
"""
is_cancelled(token::CancelToken) = token.cancelled[]

# Block the calling task until `token`'s event has been notified by `cancel!`
Base.wait(token::CancelToken) = wait(token.event)
15+
# TODO: Enable this for safety
16+
#Serialization.serialize(io::AbstractSerializer, ::CancelToken) =
17+
# throw(ConcurrencyViolationError("Cannot serialize a CancelToken"))
18+
19+
# Task-local slot for the `CancelToken` associated with the current Julia task.
# Reads as `nothing` until a token has been stored for that task.
const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing)
20+
21+
"""
    clone_cancel_token_remote(orig_token::CancelToken, wid::Integer)

Create a fresh `CancelToken` on worker `wid` and spawn a monitored task that
cancels it once `orig_token` has been cancelled.

NOTE(review): the return value is whatever `errormonitor_tracked` returns, not
the reference to the remote token. Confirm that this is what callers expect.
"""
function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer)
    # Build the token on the worker; `remote_ref` is whatever `poolset` returned there
    remote_ref = remotecall_fetch(() -> poolset(CancelToken()), wid)

    # Forwarder: block until the local token fires, then cancel the remote one
    communicator = Threads.@spawn begin
        wait(orig_token)
        @dagdebug nothing :cancel "Cancelling remote token on worker $wid"
        MemPool.access_ref(remote_ref) do tok
            cancel!(tok)
        end
    end

    return errormonitor_tracked("remote cancel_token communicator", communicator)
end
33+
34+
# Global-level cancellation
35+
136
"""
237
cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false)
338
@@ -80,11 +115,11 @@ function _cancel!(state, tid, force, halt_sch)
80115
Tf === typeof(Sch.eager_thunk) && continue
81116
istaskdone(task) && continue
82117
any_cancelled = true
83-
@dagdebug tid :cancel "Cancelling running task ($Tf)"
84118
if force
85119
@dagdebug tid :cancel "Interrupting running task ($Tf)"
86120
Threads.@spawn Base.throwto(task, InterruptException())
87121
else
122+
@dagdebug tid :cancel "Cancelling running task ($Tf)"
88123
# Tell the processor to just drop this task
89124
task_occupancy = task_spec[4]
90125
time_util = task_spec[2]
@@ -93,6 +128,7 @@ function _cancel!(state, tid, force, halt_sch)
93128
push!(istate.cancelled, tid)
94129
to_proc = istate.proc
95130
put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing)))
131+
cancel!(istate.cancel_tokens[tid])
96132
end
97133
end
98134
end

src/sch/Sch.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,7 @@ struct ProcessorInternalState
11771177
proc_occupancy::Base.RefValue{UInt32}
11781178
time_pressure::Base.RefValue{UInt64}
11791179
cancelled::Set{Int}
1180+
cancel_tokens::Dict{Int,Dagger.CancelToken}
11801181
done::Base.RefValue{Bool}
11811182
end
11821183
struct ProcessorState
@@ -1326,7 +1327,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13261327

13271328
# Execute the task and return its result
13281329
t = @task begin
1330+
# Set up cancellation
1331+
cancel_token = Dagger.CancelToken()
1332+
Dagger.DTASK_CANCEL_TOKEN[] = cancel_token
1333+
lock(istate.queue) do _
1334+
istate.cancel_tokens[thunk_id] = cancel_token
1335+
end
13291336
was_cancelled = false
1337+
13301338
result = try
13311339
do_task(to_proc, task)
13321340
catch err
@@ -1343,6 +1351,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13431351
# Task was cancelled, so occupancy and pressure are
13441352
# already reduced
13451353
pop!(istate.cancelled, thunk_id)
1354+
delete!(istate.cancel_tokens, thunk_id)
13461355
was_cancelled = true
13471356
end
13481357
end
@@ -1360,6 +1369,9 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
13601369
else
13611370
rethrow(err)
13621371
end
1372+
finally
1373+
# Ensure that any spawned tasks get cleaned up
1374+
Dagger.cancel!(cancel_token)
13631375
end
13641376
end
13651377
lock(istate.queue) do _
@@ -1409,6 +1421,7 @@ function do_tasks(to_proc, return_queue, tasks)
14091421
Dict{Int,Vector{Any}}(),
14101422
Ref(UInt32(0)), Ref(UInt64(0)),
14111423
Set{Int}(),
1424+
Dict{Int,Dagger.CancelToken}(),
14121425
Ref(false))
14131426
runner = start_processor_runner!(istate, uid, return_queue)
14141427
@static if VERSION < v"1.9"
@@ -1650,6 +1663,7 @@ function do_task(to_proc, task_desc)
16501663
sch_handle,
16511664
processor=to_proc,
16521665
task_spec=task_desc,
1666+
cancel_token=Dagger.DTASK_CANCEL_TOKEN[],
16531667
))
16541668

16551669
res = Dagger.with_options(propagated) do

src/task-tls.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ struct DTaskTLS
55
sch_uid::UInt
66
sch_handle::Any # FIXME: SchedulerHandle
77
task_spec::Vector{Any} # FIXME: TaskSpec
8+
cancel_token::CancelToken
89
end
910

1011
const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing)
@@ -22,7 +23,7 @@ get_tls() = DTASK_TLS[]::DTaskTLS
2223
Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`.
2324
"""
2425
function set_tls!(tls)
25-
DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec)
26+
DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token)
2627
end
2728

2829
"""
@@ -40,3 +41,21 @@ Get the current processor executing the current [`DTask`](@ref).
4041
"""
4142
task_processor() = get_tls().processor
4243
@deprecate thunk_processor() task_processor()
44+
45+
"""
46+
task_cancelled() -> Bool
47+
48+
Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`.
49+
"""
50+
task_cancelled() = is_cancelled(get_tls().cancel_token) # use the token's accessor rather than its field
51+
52+
"""
53+
task_may_cancel!()
54+
55+
Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled.
56+
"""
57+
function task_may_cancel!()
    # Cooperative cancellation point: bail out only if the current task's token is set
    task_cancelled() && throw(InterruptException())
    return nothing
end

src/threadproc.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n
2727
return result[]
2828
catch err
2929
if err isa InterruptException
30+
# Direct interrupt hit us, propagate cancellation signal
31+
# FIXME: We should tell the scheduler that the user hit Ctrl-C
3032
if !istaskdone(task)
31-
# Propagate cancellation signal
3233
Threads.@spawn Base.throwto(task, InterruptException())
3334
end
3435
end

0 commit comments

Comments
 (0)