Skip to content

Commit ed2493c

Browse files
committed
task-tls: Refactor into DTaskTLS struct
1 parent 5044676 commit ed2493c

File tree

5 files changed

+29
-28
lines changed

5 files changed

+29
-28
lines changed

src/Dagger.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ else
2222
import Base.ScopedValues: ScopedValue, with
2323
end
2424

25+
import TaskLocalValues: TaskLocalValue
26+
2527
if !isdefined(Base, :get_extension)
2628
import Requires: @require
2729
end

src/array/indexing.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import TaskLocalValues: TaskLocalValue
2-
31
### getindex
42

53
struct GetIndex{T,N} <: ArrayOp{T,N}

src/sch/Sch.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,7 @@ function proc_states(f::Base.Callable, uid::UInt64)
12021202
end
12031203
end
12041204
proc_states(f::Base.Callable) =
1205-
proc_states(f, task_local_storage(:_dagger_sch_uid)::UInt64)
1205+
proc_states(f, Dagger.get_tls().sch_uid)
12061206

12071207
task_tid_for_processor(::Processor) = nothing
12081208
task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid

src/sch/dynamic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct SchedulerHandle
1717
end
1818

1919
"Gets the scheduler handle for the currently-executing thunk."
20-
sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle
20+
sch_handle() = Dagger.get_tls().sch_handle::SchedulerHandle
2121

2222
"Thrown when the scheduler halts before finishing processing the DAG."
2323
struct SchedulerHaltedException <: Exception end

src/task-tls.jl

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,42 @@
11
# In-Thunk Helpers
22

3+
struct DTaskTLS
4+
processor::Processor
5+
sch_uid::UInt
6+
sch_handle::Any # FIXME: SchedulerHandle
7+
task_spec::Vector{Any} # FIXME: TaskSpec
8+
end
9+
10+
const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing)
11+
312
"""
4-
task_processor()
13+
get_tls() -> DTaskTLS
514
6-
Get the current processor executing the current Dagger task.
15+
Gets all Dagger TLS variable as a `DTaskTLS`.
716
"""
8-
task_processor() = task_local_storage(:_dagger_processor)::Processor
9-
@deprecate thunk_processor() task_processor()
17+
get_tls() = DTASK_TLS[]::DTaskTLS
1018

1119
"""
12-
in_task()
20+
set_tls!(tls)
1321
14-
Returns `true` if currently executing in a [`DTask`](@ref), else `false`.
22+
Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`.
1523
"""
16-
in_task() = haskey(task_local_storage(), :_dagger_sch_uid)
17-
@deprecate in_thunk() in_task()
24+
function set_tls!(tls)
25+
DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec)
26+
end
1827

1928
"""
20-
get_tls()
29+
in_task() -> Bool
2130
22-
Gets all Dagger TLS variable as a `NamedTuple`.
31+
Returns `true` if currently executing in a [`DTask`](@ref), else `false`.
2332
"""
24-
get_tls() = (
25-
sch_uid=task_local_storage(:_dagger_sch_uid),
26-
sch_handle=task_local_storage(:_dagger_sch_handle),
27-
processor=task_processor(),
28-
task_spec=task_local_storage(:_dagger_task_spec),
29-
)
33+
in_task() = DTASK_TLS[] !== nothing
34+
@deprecate in_thunk() in_task()
3035

3136
"""
32-
set_tls!(tls)
37+
task_processor() -> Processor
3338
34-
Sets all Dagger TLS variables from the `NamedTuple` `tls`.
39+
Get the current processor executing the current [`DTask`](@ref).
3540
"""
36-
function set_tls!(tls)
37-
task_local_storage(:_dagger_sch_uid, tls.sch_uid)
38-
task_local_storage(:_dagger_sch_handle, tls.sch_handle)
39-
task_local_storage(:_dagger_processor, tls.processor)
40-
task_local_storage(:_dagger_task_spec, tls.task_spec)
41-
end
41+
task_processor() = get_tls().processor
42+
@deprecate thunk_processor() task_processor()

0 commit comments

Comments
 (0)