Skip to content

Commit 1d8d8d2

Browse files
authored
Merge pull request #5 from JuliaFolds2/atomic
Atomic storage, threadpool support
2 parents 283b178 + cdea68a commit 1d8d8d2

File tree

6 files changed

+52
-19
lines changed

6 files changed

+52
-19
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
version:
13-
- '1.6'
1413
- '1.9'
1514
- '1.10.0'
1615
- 'nightly'

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ authors = ["Mason Protter <[email protected]>"]
44
version = "0.1.3"
55

66
[compat]
7-
julia = "1.6"
7+
julia = "1.9"
88

99
[extras]
1010
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
StableTasks is a simple package with one main API `StableTasks.@spawn` (not exported by default).
44

5-
It works like `Threads.@spawn`, except it is *type stable* to `fetch` from (and it does not yet support threadpools
6-
other than the default threadpool).
5+
It works like `Threads.@spawn`, except it is *type stable* to `fetch` from.
76

87
``` julia
98
julia> Core.Compiler.return_type(() -> fetch(StableTasks.@spawn 1 + 1), Tuple{})

src/StableTasks.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@ module StableTasks
33
macro spawn end
44
macro spawnat end
55

6-
using Base: RefValue
6+
mutable struct AtomicRef{T}
7+
@atomic x::T
8+
AtomicRef{T}() where {T} = new{T}()
9+
AtomicRef(x::T) where {T} = new{T}(x)
10+
AtomicRef{T}(x) where {T} = new{T}(convert(T, x))
11+
end
12+
713
struct StableTask{T}
814
t::Task
9-
ret::RefValue{T}
15+
ret::AtomicRef{T}
1016
end
1117

1218
include("internals.jl")

src/internals.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
module Internals
22

3-
import StableTasks: @spawn, @spawnat, StableTask
3+
import StableTasks: @spawn, @spawnat, StableTask, AtomicRef
4+
5+
Base.getindex(r::AtomicRef) = @atomic r.x
6+
Base.setindex!(r::AtomicRef{T}, x) where {T} = @atomic r.x = convert(T, x)
47

58
function Base.fetch(t::StableTask{T}) where {T}
69
fetch(t.t)
@@ -25,41 +28,58 @@ Base.schedule(t::StableTask) = (schedule(t.t); t)
2528
Base.schedule(t, val; error=false) = (schedule(t.t, val; error); t)
2629

2730
"""
28-
Similar to `Threads.@spawn` but type-stable. Creates a `Task` and schedules it to run on any available thread in the `:default` threadpool.
31+
@spawn [:default|:interactive] expr
32+
33+
Similar to `Threads.@spawn` but type-stable. Creates a `Task` and schedules it to run on any available
34+
thread in the specified threadpool (defaults to the `:default` threadpool).
2935
"""
30-
macro spawn(ex)
36+
macro spawn(args...)
37+
tp = QuoteNode(:default)
38+
na = length(args)
39+
if na == 2
40+
ttype, ex = args
41+
if ttype isa QuoteNode
42+
ttype = ttype.value
43+
if ttype !== :interactive && ttype !== :default
44+
throw(ArgumentError("unsupported threadpool in StableTasks.@spawn: $ttype"))
45+
end
46+
tp = QuoteNode(ttype)
47+
else
48+
tp = ttype
49+
end
50+
elseif na == 1
51+
ex = args[1]
52+
else
53+
throw(ArgumentError("wrong number of arguments in @spawn"))
54+
end
55+
3156
letargs = _lift_one_interp!(ex)
3257

3358
thunk = replace_linenums!(:(() -> ($(esc(ex)))), __source__)
3459
var = esc(Base.sync_varname) # This is for the @sync macro which sets a local variable whose name is
3560
# the symbol bound to Base.sync_varname
3661
# I asked on slack and this is apparently safe to consider a public API
37-
set_pool = if VERSION < v"1.9"
38-
nothing
39-
else
40-
:(Threads._spawn_set_thrpool(task, :default))
41-
end
4262
quote
4363
let $(letargs...)
4464
f = $thunk
4565
T = Core.Compiler.return_type(f, Tuple{})
46-
ref = Ref{T}()
66+
ref = AtomicRef{T}()
4767
f_wrap = () -> (ref[] = f(); nothing)
4868
task = Task(f_wrap)
4969
task.sticky = false
50-
$set_pool
70+
Threads._spawn_set_thrpool(task, $(esc(tp)))
5171
if $(Expr(:islocal, var))
5272
put!($var, task) # Sync will set up a Channel, and we want our task to be in there.
5373
end
5474
schedule(task)
55-
StableTask(task, ref)
75+
StableTask{T}(task, ref)
5676
end
5777
end
5878
end
5979

6080
"""
6181
Similar to `StableTasks.@spawn` but creates a **sticky** `Task` and schedules it to run on the thread with the given id (`thrdid`).
62-
The task is guaranteed to stay on this thread (it won't migrate to another thread).
82+
The task is guaranteed to stay on this thread (it won't migrate to another thread).
6383
"""
6484
macro spawnat(thrdid, ex)
6585
letargs = _lift_one_interp!(ex)
@@ -81,7 +101,7 @@ macro spawnat(thrdid, ex)
81101
let $(letargs...)
82102
thunk = $thunk
83103
RT = Core.Compiler.return_type(thunk, Tuple{})
84-
ret = Ref{RT}()
104+
ret = AtomicRef{RT}()
85105
thunk_wrap = () -> (ret[] = thunk(); nothing)
86106
local task = Task(thunk_wrap)
87107
task.sticky = true

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ using StableTasks: @spawn, @spawnat
66
t = @eval @spawn inv([1 2 ; 3 4])
77
@test inv([1 2 ; 3 4]) == @inferred fetch(t)
88

9+
@test 2 == @inferred fetch(@spawn :interactive 1 + 1)
10+
t = @eval @spawn :interactive inv([1 2 ; 3 4])
11+
@test inv([1 2 ; 3 4]) == @inferred fetch(t)
12+
13+
s = :default
14+
@test 2 == @inferred fetch(@spawn s 1 + 1)
15+
t = @eval @spawn $(QuoteNode(s)) inv([1 2 ; 3 4])
16+
@test inv([1 2 ; 3 4]) == @inferred fetch(t)
17+
918
@test 2 == @inferred fetch(@spawnat 1 1 + 1)
1019
t = @eval @spawnat 1 inv([1 2 ; 3 4])
1120
@test inv([1 2 ; 3 4]) == @inferred fetch(t)

0 commit comments

Comments
 (0)