Skip to content

Commit 515e731

Browse files
committed
parser: Support do-blocks
1 parent 8d29bd8 commit 515e731

File tree

2 files changed

+58
-16
lines changed

2 files changed

+58
-16
lines changed

src/thunk.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -363,37 +363,47 @@ function replace_broadcast(fn::Symbol)
363363
end
364364

365365
function _par(ex::Expr; lazy=true, recur=true, opts=())
366-
if ex.head == :call && recur
367-
f = replace_broadcast(ex.args[1])
368-
if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters)
369-
args = ex.args[3:end]
370-
kwargs = ex.args[2]
371-
else
372-
args = ex.args[2:end]
373-
kwargs = Expr(:parameters)
366+
body = nothing
367+
if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end)
368+
f = replace_broadcast(f)
369+
args = filter(arg->!Meta.isexpr(arg, :parameters), allargs)
370+
kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs)
371+
if !isempty(kwargs)
372+
kwargs = only(kwargs).args
373+
end
374+
if body !== nothing
375+
f = quote
376+
($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs
377+
$body
378+
end
379+
end
374380
end
375-
args_ex = _par.(args; lazy=lazy, recur=false)
376-
kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false)
377381
if lazy
378-
return :(Dagger.delayed($f, $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...)))
382+
return :(Dagger.delayed($f, $Options(;$(opts...)))($(args...); $(kwargs...)))
379383
else
380384
sync_var = Base.sync_varname
381385
@gensym result
382386
return quote
383-
let args = ($(args_ex...),)
384-
$result = $spawn($f, $Options(;$(opts...)), args...; $(kwargs_ex...))
387+
let
388+
$result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...))
385389
if $(Expr(:islocal, sync_var))
386390
put!($sync_var, schedule(Task(()->wait($result))))
387391
end
388392
$result
389393
end
390394
end
391395
end
396+
elseif lazy
397+
# Recurse into the expression
398+
return Expr(ex.head, _par_inner.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
392399
else
393-
return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
400+
throw(ArgumentError("Invalid Dagger task expression: $ex"))
394401
end
395402
end
396-
_par(ex; kwargs...) = ex
403+
_par(ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex"))
404+
405+
_par_inner(ex; kwargs...) = ex
406+
_par_inner(ex::Expr; kwargs...) = _par(ex; kwargs...)
397407

398408
"""
399409
Dagger.spawn(f, args...; kwargs...) -> DTask

test/thunk.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,39 @@ end
8181
end
8282
@testset "inner macro" begin
8383
A = rand(4)
84-
@test fetch(@spawn sum(@view A[2:3])) sum(@view A[2:3])
84+
t = @spawn sum(@view A[2:3])
85+
@test t isa Dagger.DTask
86+
@test fetch(t) sum(@view A[2:3])
87+
end
88+
@testset "do block" begin
89+
A = rand(4)
90+
91+
t = @spawn sum(A) do a
92+
a + 1
93+
end
94+
@test t isa Dagger.DTask
95+
@test fetch(t) sum(a->a+1, A)
96+
97+
t = @spawn sum(A; dims=1) do a
98+
a + 1
99+
end
100+
@test t isa Dagger.DTask
101+
@test fetch(t) sum(a->a+1, A; dims=1)
102+
103+
do_f = f -> f(42)
104+
t = @spawn do_f() do x
105+
x + 1
106+
end
107+
@test t isa Dagger.DTask
108+
@test fetch(t) == 43
109+
end
110+
@testset "invalid expression" begin
111+
@test_throws LoadError eval(:(@spawn 1))
112+
@test_throws LoadError eval(:(@spawn begin 1 end))
113+
@test_throws LoadError eval(:(@spawn begin
114+
1+1
115+
1+1
116+
end))
85117
end
86118
@testset "waiting" begin
87119
a = @spawn sleep(1)

0 commit comments

Comments
 (0)