Skip to content

Commit b2fd2ab

Browse files
authored
Merge pull request #533 from JuliaParallel/jps/parser-fixes
parser: Assorted improvements
2 parents 684d80c + 86f2f5a commit b2fd2ab

File tree

3 files changed

+108
-20
lines changed

3 files changed

+108
-20
lines changed

src/thunk.jl

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ generated thunks.
306306
macro par(exs...)
307307
opts = exs[1:end-1]
308308
ex = exs[end]
309-
_par(ex; lazy=true, opts=opts)
309+
return esc(_par(ex; lazy=true, opts=opts))
310310
end
311311

312312
"""
@@ -348,7 +348,7 @@ also passes along any options in an `Options` struct. For example,
348348
macro spawn(exs...)
349349
opts = exs[1:end-1]
350350
ex = exs[end]
351-
_par(ex; lazy=false, opts=opts)
351+
return esc(_par(ex; lazy=false, opts=opts))
352352
end
353353

354354
struct ExpandedBroadcast{F} end
@@ -363,39 +363,62 @@ function replace_broadcast(fn::Symbol)
363363
end
364364

365365
function _par(ex::Expr; lazy=true, recur=true, opts=())
366-
if ex.head == :call && recur
367-
f = replace_broadcast(ex.args[1])
368-
if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters)
369-
args = ex.args[3:end]
370-
kwargs = ex.args[2]
371-
else
372-
args = ex.args[2:end]
373-
kwargs = Expr(:parameters)
366+
f = nothing
367+
body = nothing
368+
arg1 = nothing
369+
if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) || @capture(ex, arg1_[allargs__])
370+
f = replace_broadcast(f)
371+
if arg1 !== nothing
372+
# Indexing (A[2,3])
373+
f = Base.getindex
374+
pushfirst!(allargs, arg1)
375+
end
376+
args = filter(arg->!Meta.isexpr(arg, :parameters), allargs)
377+
kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs)
378+
if !isempty(kwargs)
379+
kwargs = only(kwargs).args
380+
end
381+
if body !== nothing
382+
if f !== nothing
383+
f = quote
384+
($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs
385+
$body
386+
end
387+
end
388+
else
389+
f = quote
390+
($(args...); $(kwargs...))->begin
391+
$body
392+
end
393+
end
394+
end
374395
end
375-
opts = esc.(opts)
376-
args_ex = _par.(args; lazy=lazy, recur=false)
377-
kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false)
378396
if lazy
379-
return :(Dagger.delayed($(esc(f)), $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...)))
397+
return :(Dagger.delayed($f, $Options(;$(opts...)))($(args...); $(kwargs...)))
380398
else
381-
sync_var = esc(Base.sync_varname)
399+
sync_var = Base.sync_varname
382400
@gensym result
383401
return quote
384-
let args = ($(args_ex...),)
385-
$result = $spawn($(esc(f)), $Options(;$(opts...)), args...; $(kwargs_ex...))
402+
let
403+
$result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...))
386404
if $(Expr(:islocal, sync_var))
387405
put!($sync_var, schedule(Task(()->wait($result))))
388406
end
389407
$result
390408
end
391409
end
392410
end
411+
elseif lazy
412+
# Recurse into the expression
413+
return Expr(ex.head, _par_inner.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
393414
else
394-
return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
415+
throw(ArgumentError("Invalid Dagger task expression: $ex"))
395416
end
396417
end
397-
_par(ex::Symbol; kwargs...) = esc(ex)
398-
_par(ex; kwargs...) = ex
418+
_par(ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex"))
419+
420+
_par_inner(ex; kwargs...) = ex
421+
_par_inner(ex::Expr; kwargs...) = _par(ex; kwargs...)
399422

400423
"""
401424
Dagger.spawn(f, args...; kwargs...) -> DTask

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__
3131
pushfirst!(LOAD_PATH, joinpath(@__DIR__, ".."))
3232
using Pkg
3333
Pkg.activate(@__DIR__)
34+
Pkg.instantiate()
3435

3536
using ArgParse
3637
s = ArgParseSettings(description = "Dagger Testsuite")

test/thunk.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,70 @@ end
7979
@test fetch(@spawn A .+ B) A .+ B
8080
@test fetch(@spawn A .* B) A .* B
8181
end
82+
@testset "inner macro" begin
83+
A = rand(4)
84+
t = @spawn sum(@view A[2:3])
85+
@test t isa Dagger.DTask
86+
@test fetch(t) sum(@view A[2:3])
87+
end
88+
@testset "do block" begin
89+
A = rand(4)
90+
91+
t = @spawn sum(A) do a
92+
a + 1
93+
end
94+
@test t isa Dagger.DTask
95+
@test fetch(t) sum(a->a+1, A)
96+
97+
t = @spawn sum(A; dims=1) do a
98+
a + 1
99+
end
100+
@test t isa Dagger.DTask
101+
@test fetch(t) sum(a->a+1, A; dims=1)
102+
103+
do_f = f -> f(42)
104+
t = @spawn do_f() do x
105+
x + 1
106+
end
107+
@test t isa Dagger.DTask
108+
@test fetch(t) == 43
109+
end
110+
@testset "anonymous direct call" begin
111+
A = rand(4)
112+
113+
t = @spawn A->sum(A)
114+
@test t isa Dagger.DTask
115+
@test fetch(t) == sum(A)
116+
117+
t = @spawn A->sum(A; dims=1)
118+
@test t isa Dagger.DTask
119+
@test fetch(t) == sum(A; dims=1)
120+
end
121+
@testset "getindex" begin
122+
A = rand(4, 4)
123+
124+
t = @spawn A[1, 2]
125+
@test t isa Dagger.DTask
126+
@test fetch(t) == A[1, 2]
127+
128+
B = Dagger.@spawn rand(4, 4)
129+
t = @spawn B[1, 2]
130+
@test t isa Dagger.DTask
131+
@test fetch(t) == fetch(B)[1, 2]
132+
133+
R = Ref(42)
134+
t = @spawn R[]
135+
@test t isa Dagger.DTask
136+
@test fetch(t) == 42
137+
end
138+
@testset "invalid expression" begin
139+
@test_throws LoadError eval(:(@spawn 1))
140+
@test_throws LoadError eval(:(@spawn begin 1 end))
141+
@test_throws LoadError eval(:(@spawn begin
142+
1+1
143+
1+1
144+
end))
145+
end
82146
@testset "waiting" begin
83147
a = @spawn sleep(1)
84148
@test !isready(a)

0 commit comments

Comments
 (0)