Skip to content

Commit 00f6625

Browse files
committed
spawnat
1 parent 3701655 commit 00f6625

File tree

5 files changed

+73
-10
lines changed

5 files changed

+73
-10
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Manifest.toml
2+
.vscode

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,12 @@ versus
1515
julia> Core.Compiler.return_type(() -> fetch(Threads.@spawn 1 + 1), Tuple{})
1616
Any
1717
```
18+
19+
The package also provides `StableTasks.@spawnat` (not exported), which is similar to `StableTasks.@spawn` but creates a *sticky* task (it won't migrate) on a specific thread.
20+
21+
```julia
22+
julia> t = StableTasks.@spawnat 4 Threads.threadid();
23+
24+
julia> fetch(t)
25+
4
26+
```

src/StableTasks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module StableTasks
22

33
macro spawn end
4+
macro spawnat end
45

56
using Base: RefValue
67
struct StableTask{T}

src/internals.jl

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
module Internals
1+
module Internals
22

3-
import StableTasks: @spawn, StableTask
3+
import StableTasks: @spawn, @spawnat, StableTask
44

55
function Base.fetch(t::StableTask{T}) where {T}
66
fetch(t.t)
@@ -26,14 +26,12 @@ Base.schedule(t, val; error=false) = (schedule(t.t, val; error); t)
2626

2727

2828
macro spawn(ex)
29-
tp = QuoteNode(:default)
30-
3129
letargs = _lift_one_interp!(ex)
3230

33-
thunk = replace_linenums!(:(()->($(esc(ex)))), __source__)
31+
thunk = replace_linenums!(:(() -> ($(esc(ex)))), __source__)
3432
var = esc(Base.sync_varname) # This is for the @sync macro which sets a local variable whose name is
35-
# the symbol bound to Base.sync_varname
36-
# I asked on slack and this is apparently safe to consider a public API
33+
# the symbol bound to Base.sync_varname
34+
# I asked on slack and this is apparently safe to consider a public API
3735
quote
3836
let $(letargs...)
3937
f = $thunk
@@ -51,6 +49,39 @@ macro spawn(ex)
5149
end
5250
end
5351

52+
macro spawnat(thrdid, ex)
53+
letargs = _lift_one_interp!(ex)
54+
55+
thunk = replace_linenums!(:(() -> ($(esc(ex)))), __source__)
56+
var = esc(Base.sync_varname)
57+
58+
tid = esc(thrdid)
59+
@static if VERSION < v"1.9"
60+
nt = :(Threads.nthreads())
61+
else
62+
nt = :(Threads.maxthreadid())
63+
end
64+
quote
65+
if $tid < 1 || $tid > $nt
66+
throw(ArgumentError("Invalid thread id ($($tid)). Must be between in " *
67+
"1:(total number of threads), i.e. $(1:$nt)."))
68+
end
69+
let $(letargs...)
70+
thunk = $thunk
71+
RT = Core.Compiler.return_type(thunk, Tuple{})
72+
ret = Ref{RT}()
73+
thunk_wrap = () -> (ret[] = thunk(); nothing)
74+
local task = Task(thunk_wrap)
75+
task.sticky = true
76+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), task, $tid - 1)
77+
if $(Expr(:islocal, var))
78+
put!($var, task)
79+
end
80+
schedule(task)
81+
StableTask(task, ret)
82+
end
83+
end
84+
end
5485

5586
# Copied from base rather than calling it directly because who knows if it'll change in the future
5687
function _lift_one_interp!(e)
@@ -74,7 +105,7 @@ function _lift_one_interp_helper(expr::Expr, in_quote_context, letargs)
74105
elseif expr.head === :macrocall
75106
return expr # Don't recur into macro calls, since some other macros use $
76107
end
77-
for (i,e) in enumerate(expr.args)
108+
for (i, e) in enumerate(expr.args)
78109
expr.args[i] = _lift_one_interp_helper(e, in_quote_context, letargs)
79110
end
80111
expr

test/runtests.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
using Test, StableTasks
2-
using StableTasks: @spawn
2+
using StableTasks: @spawn, @spawnat
33

44
@testset "Type stability" begin
5-
@test 2 ==@inferred fetch(@spawn 1 + 1)
5+
@test 2 == @inferred fetch(@spawn 1 + 1)
66
t = @eval @spawn inv([1 2 ; 3 4])
77
@test inv([1 2 ; 3 4]) == @inferred fetch(t)
8+
9+
@test 2 == @inferred fetch(@spawnat 1 1 + 1)
10+
t = @eval @spawnat 1 inv([1 2 ; 3 4])
11+
@test inv([1 2 ; 3 4]) == @inferred fetch(t)
812
end
913

1014
@testset "API funcs" begin
@@ -22,4 +26,20 @@ end
2226
@test r[] == 0
2327
end
2428
@test r[] == 1
29+
30+
T = @spawnat 1 rand(Bool)
31+
@test isnothing(wait(T))
32+
@test istaskdone(T)
33+
@test istaskfailed(T) == false
34+
@test istaskstarted(T)
35+
@test fetch(@spawnat 1 Threads.threadid()) == 1
36+
r = Ref(0)
37+
@sync begin
38+
@spawnat 1 begin
39+
sleep(5)
40+
r[] = 1
41+
end
42+
@test r[] == 0
43+
end
44+
@test r[] == 1
2545
end

0 commit comments

Comments
 (0)