
Commit b9ce129

Add keyword argument support
APIs like `delayed` and `spawn` assumed that passed kwargs were to be treated as options to the scheduler, which is both somewhat confusing for users and precludes passing kwargs to user functions.

This commit changes those APIs, as well as `@spawn`, to instead pass kwargs directly to the user's function. Options are now passed to `delayed` and `spawn` in an `Options` struct as the second argument (the first being the function), while `@spawn` still keeps them before the call (which is generally more convenient).

Internally, `Thunk`'s `inputs` field is now a `Vector{Pair{Union{Symbol,Nothing},Any}}`, where the second element of each pair is the argument and the first element is its position: if `nothing`, it's a positional argument; if a `Symbol`, it's a kwarg.
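As a rough illustration of the new `inputs` encoding described above (the concrete values are hypothetical, not taken from the commit):

```julia
# Each argument is stored as position => value, where the position is
# `nothing` for a positional argument and a `Symbol` for a kwarg.
inputs = Pair{Union{Symbol,Nothing},Any}[
    nothing => 4,    # first positional argument
    nothing => 38,   # second positional argument
    :rev => true,    # keyword argument `rev`
]

# Splitting them back apart, as a consumer of `inputs` might before
# calling the user's function as f(args...; kwargs...):
args   = Any[v for (k, v) in inputs if k === nothing]
kwargs = (; (k => v for (k, v) in inputs if k !== nothing)...)

@assert args == Any[4, 38]
@assert kwargs == (rev = true,)
```

The single vector keeps positional and keyword arguments in one place, so scheduler code can treat every input uniformly and only distinguish them at call time.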
1 parent 4b83c4b commit b9ce129

19 files changed: +216 −125 lines

docs/src/checkpointing.md

Lines changed: 3 additions & 3 deletions

````diff
@@ -54,17 +54,17 @@ Let's see how we'd modify the above example to use checkpointing:
 
 ```julia
 using Serialization
+
 X = compute(randn(Blocks(128,128), 1024, 1024))
-Y = [delayed(sum; options=Dagger.Sch.ThunkOptions(;
-    checkpoint=(thunk,result)->begin
+Y = [delayed(sum; checkpoint=(thunk,result)->begin
     open("checkpoint-$idx.bin", "w") do io
         serialize(io, collect(result))
     end
 end, restore=(thunk)->begin
     open("checkpoint-$idx.bin", "r") do io
         Dagger.tochunk(deserialize(io))
     end
-end))(chunk) for (idx,chunk) in enumerate(X.chunks)]
+end)(chunk) for (idx,chunk) in enumerate(X.chunks)]
 inner(x...) = sqrt(sum(x))
 Z = delayed(inner)(Y...)
 z = collect(Z)
````
docs/src/index.md

Lines changed: 23 additions & 22 deletions

````diff
@@ -2,32 +2,34 @@
 
 ## Usage
 
-The main function for using Dagger is `spawn`:
+The main entrypoint to Dagger is `@spawn`:
 
-`Dagger.spawn(f, args...; options...)`
+`Dagger.@spawn [option=value]... f(args...; kwargs...)`
 
-or `@spawn` for the more convenient macro form:
+or `spawn` if it's more convenient:
 
-`Dagger.@spawn [option=value]... f(args...)`
+`Dagger.spawn(f, Dagger.Options(options), args...; kwargs...)`
 
 When called, it creates an `EagerThunk` (also known as a "thunk" or "task")
-object representing a call to function `f` with the arguments `args`. If it is
-called with other thunks as inputs, such as in `Dagger.@spawn f(Dagger.@spawn
-g())`, then the function `f` gets passed the results of those input thunks. If
-those thunks aren't yet finished executing, then the execution of `f` waits on
-all of its input thunks to complete before executing.
+object representing a call to function `f` with the arguments `args` and
+keyword arguments `kwargs`. If it is called with other thunks as args/kwargs,
+such as in `Dagger.@spawn f(Dagger.@spawn g())`, then the function `f` gets
+passed the results of those input thunks, once they're available. If those
+thunks aren't yet finished executing, then the execution of `f` waits on all of
+its input thunks to complete before executing.
 
 The key point is that, for each argument to a thunk, if the argument is an
 `EagerThunk`, it'll be executed before this node and its result will be passed
 into the function `f`. If the argument is *not* an `EagerThunk` (instead, some
 other type of Julia object), it'll be passed as-is to the function `f`.
 
-Thunks don't accept regular keyword arguments for the function `f`. Instead,
-the `options` kwargs are passed to the scheduler to control its behavior:
+The `Options` struct in the second argument position is optional; if provided,
+it is passed to the scheduler to control its behavior. `Options` contains a
+`NamedTuple` of option key-value pairs, which can be any of:
 - Any field in `Dagger.Sch.ThunkOptions` (see [Scheduler and Thunk options](@ref))
 - `meta::Bool` -- Pass the input `Chunk` objects themselves to `f` and not the value contained in them
 
-There are also some extra kwargs that can be passed, although they're considered advanced options to be used only by developers or library authors:
+There are also some extra options that can be passed, although they're considered advanced options to be used only by developers or library authors:
 - `get_result::Bool` -- return the actual result to the scheduler instead of `Chunk` objects. Used when `f` explicitly constructs a Chunk or when return value is small (e.g. in case of reduce)
 - `persist::Bool` -- the result of this Thunk should not be released after it becomes unused in the DAG
 - `cache::Bool` -- cache the result of this Thunk such that if the thunk is evaluated again, one can just reuse the cached value. If it's been removed from cache, recompute the value.
@@ -133,18 +135,18 @@ via `@par` or `delayed`. The above computation can be executed with the lazy
 API by substituting `@spawn` with `@par` and `fetch` with `collect`:
 
 ```julia
-p = @par add1(4)
-q = @par add2(p)
-r = @par add1(3)
-s = @par combine(p, q, r)
+p = Dagger.@par add1(4)
+q = Dagger.@par add2(p)
+r = Dagger.@par add1(3)
+s = Dagger.@par combine(p, q, r)
 
 @assert collect(s) == 16
 ```
 
 or similarly, in block form:
 
 ```julia
-s = @par begin
+s = Dagger.@par begin
     p = add1(4)
     q = add2(p)
     r = add1(3)
@@ -159,7 +161,7 @@ operation, you can call `compute` on the thunk. This will return a `Chunk`
 object which references the result (see [Chunks](@ref) for more details):
 
 ```julia
-x = @par 1+2
+x = Dagger.@par 1+2
 cx = compute(x)
 cx::Chunk
 @assert collect(cx) == 3
@@ -207,7 +209,7 @@ Scheduler options can be constructed and passed to `collect()` or `compute()`
 as the keyword argument `options` for lazy API usage:
 
 ```julia
-t = @par 1+2
+t = Dagger.@par 1+2
 opts = Dagger.Sch.SchedulerOptions(;single=1) # Execute on worker 1
 
 compute(t; options=opts)
@@ -221,10 +223,9 @@ Thunk options can be passed to `@spawn/spawn`, `@par`, and `delayed` similarly:
 # Execute on worker 1
 
 Dagger.@spawn single=1 1+2
-Dagger.spawn(+, 1, 2; single=1)
+Dagger.spawn(+, Dagger.Options(;single=1), 1, 2)
 
-opts = Dagger.Sch.ThunkOptions(;single=1)
-delayed(+)(1, 2; options=opts)
+delayed(+; single=1)(1, 2)
 ```
 
 ### Core vs. Worker Schedulers
````
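Under the new API documented above, keyword arguments reach the user's function while scheduler options move into `Dagger.Options` (for `spawn`) or stay before the call (for `@spawn`). A hedged sketch, assuming Dagger is loaded and a scheduler is running; the `sort` call is only an illustration:

```julia
using Dagger

# `single=1` is a scheduler option; `rev=true` is a kwarg for `sort`:
t = Dagger.@spawn single=1 sort([3, 1, 2]; rev=true)
@assert fetch(t) == [3, 2, 1]

# Equivalent function form; Options carries the scheduler options:
t2 = Dagger.spawn(sort, Dagger.Options(;single=1), [3, 1, 2]; rev=true)
@assert fetch(t2) == [3, 2, 1]
```

Before this commit, `rev=true` in the `spawn` form would have been swallowed as an (unknown) scheduler option rather than forwarded to `sort`.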

docs/src/propagation.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 # Option Propagation
 
-Most options passed to Dagger are passed via `delayed` or `Dagger.@spawn`
+Most options passed to Dagger are passed via `@spawn/spawn` or `delayed`
 directly. This works well when an option only needs to be set for a single
 thunk, but is cumbersome when the same option needs to be set on multiple
 thunks, or set recursively on thunks spawned within other thunks. Thankfully,
```

src/array/darray.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -251,7 +251,7 @@ function thunkize(ctx::Context, c::DArray; persist=true)
     if persist
         foreach(persist!, thunks)
     end
-    Thunk(thunks...; meta=true) do results...
+    Thunk(map(thunk->nothing=>thunk, thunks)...; meta=true) do results...
         t = eltype(results[1])
         DArray(t, dmn, dmnchunks,
                reshape(Union{Chunk,Thunk}[results...], sz))
```
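The `nothing=>thunk` wrapping seen throughout these internal call sites just builds the position => argument pairs for positional inputs. A minimal stand-alone sketch of the pattern, using strings as stand-ins for `Thunk` objects:

```julia
thunks = ["t1", "t2", "t3"]  # stand-ins for Thunk objects

# Wrap each positional input as `nothing => input`, matching the new
# Thunk argument convention (no Symbol key means "positional"):
wrapped = map(thunk -> nothing => thunk, thunks)

@assert wrapped == [nothing => "t1", nothing => "t2", nothing => "t3"]
@assert all(first(p) === nothing for p in wrapped)
```

Every internal caller has to wrap explicitly because `Thunk` now expects pairs, keeping the constructor unambiguous about which inputs are kwargs.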

src/array/map-reduce.jl

Lines changed: 5 additions & 5 deletions

```diff
@@ -19,7 +19,7 @@ function stage(ctx::Context, node::Map)
     f = node.f
     for i=eachindex(domains)
         inps = map(x->chunks(x)[i], inputs)
-        thunks[i] = Thunk((args...) -> map(f, args...), inps...)
+        thunks[i] = Thunk((args...) -> map(f, args...), map(inp->nothing=>inp, inps)...)
     end
     DArray(Any, domain(primary), domainchunks(primary), thunks)
 end
@@ -40,8 +40,8 @@ end
 
 function stage(ctx::Context, r::ReduceBlock)
     inp = stage(ctx, r.input)
-    reduced_parts = map(x -> Thunk(r.op, x; get_result=r.get_result), chunks(inp))
-    Thunk((xs...) -> r.op_master(xs), reduced_parts...; meta=true)
+    reduced_parts = map(x -> Thunk(r.op, nothing=>x; get_result=r.get_result), chunks(inp))
+    Thunk((xs...) -> r.op_master(xs), map(part->nothing=>part, reduced_parts)...; meta=true)
 end
 
 reduceblock_async(f, x::ArrayOp; get_result=true) = ReduceBlock(f, f, x, get_result)
@@ -126,10 +126,10 @@ function stage(ctx::Context, r::Reducedim)
     inp = cached_stage(ctx, r.input)
     thunks = let op = r.op, dims=r.dims
         # do reducedim on each block
-        tmp = map(p->Thunk(b->reduce(op,b,dims=dims), p), chunks(inp))
+        tmp = map(p->Thunk(b->reduce(op,b,dims=dims), nothing=>p), chunks(inp))
         # combine the results in tree fashion
         treereducedim(tmp, r.dims) do x,y
-            Thunk(op, x,y)
+            Thunk(op, nothing=>x, nothing=>y)
         end
     end
     c = domainchunks(inp)
```

src/array/matrix.jl

Lines changed: 5 additions & 5 deletions

```diff
@@ -17,10 +17,10 @@ function size(x::Transpose)
 end
 
 transpose(x::ArrayOp) = Transpose(transpose, x)
-transpose(x::Union{Chunk, Thunk}) = Thunk(transpose, x)
+transpose(x::Union{Chunk, Thunk}) = Thunk(transpose, nothing=>x)
 
 adjoint(x::ArrayOp) = Transpose(adjoint, x)
-adjoint(x::Union{Chunk, Thunk}) = Thunk(adjoint, x)
+adjoint(x::Union{Chunk, Thunk}) = Thunk(adjoint, nothing=>x)
 
 function adjoint(x::ArrayDomain{2})
     d = indexes(x)
@@ -91,8 +91,8 @@ function (+)(a::ArrayDomain, b::ArrayDomain)
     a
 end
 
-(*)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(*, a,b)
-(+)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(+, a,b)
+(*)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(*, nothing=>a, nothing=>b)
+(+)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(+, nothing=>a, nothing=>b)
 
 # we define our own matmat and matvec multiply
 # for computing the new domains and thunks.
@@ -211,7 +211,7 @@ end
 function _scale(l, r)
     res = similar(r, Any)
     for i=1:length(l)
-        res[i,:] = map(x->Thunk((a,b) -> Diagonal(a)*b, l[i], x), r[i,:])
+        res[i,:] = map(x->Thunk((a,b) -> Diagonal(a)*b, nothing=>l[i], nothing=>x), r[i,:])
     end
     res
 end
```

src/array/operators.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -107,7 +107,7 @@ Base.@deprecate mappart(args...) mapchunk(args...)
 function stage(ctx::Context, node::MapChunk)
     inputs = map(x->cached_stage(ctx, x), node.input)
     thunks = map(map(chunks, inputs)...) do ps...
-        Thunk(node.f, ps...)
+        Thunk(node.f, map(p->nothing=>p, ps)...)
     end
 
     DArray(Any, domain(inputs[1]), domainchunks(inputs[1]), thunks)
```

src/array/setindex.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -35,7 +35,7 @@ function stage(ctx::Context, sidx::SetIndex)
     local_dmn = ArrayDomain(map(x->x[2], idx_and_dmn))
     s = subdmns[idx...]
     part_to_set = sidx.val
-    ps[idx...] = Thunk(ps[idx...]) do p
+    ps[idx...] = Thunk(nothing=>ps[idx...]) do p
         q = copy(p)
         q[indexes(project(s, local_dmn))...] .= part_to_set
         q
```

src/compute.jl

Lines changed: 2 additions & 2 deletions

```diff
@@ -97,7 +97,7 @@ function dependents(node::Thunk)
         if !haskey(deps, next)
             deps[next] = Set{Thunk}()
         end
-        for inp in inputs(next)
+        for (_, inp) in next.inputs
             if istask(inp) || (inp isa Chunk)
                 s = get!(()->Set{Thunk}(), deps, inp)
                 push!(s, next)
@@ -165,7 +165,7 @@ function order(node::Thunk, ndeps)
         haskey(output, next) && continue
         s += 1
         output[next] = s
-        parents = filter(istask, inputs(next))
+        parents = filter(istask, map(last, next.inputs))
         if !isempty(parents)
             # If parents is empty, sort! should be a no-op, but raises an ambiguity error
             # when InlineStrings.jl is loaded (at least, version 1.1.0), because InlineStrings
```
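Since every entry of `inputs` is now a `Pair`, graph-walking code either destructures each entry or projects out the value with `last`. A small sketch of both access patterns used in this file:

```julia
# inputs as stored on a Thunk: position => argument
inputs = Pair{Union{Symbol,Nothing},Any}[nothing => 10, :dims => 2, nothing => 20]

# Destructuring, as in `dependents`: ignore the position, use the value
vals = Any[]
for (_, inp) in inputs
    push!(vals, inp)
end

# Projecting with `last`, as in `order`:
@assert map(last, inputs) == vals == Any[10, 2, 20]
```

Either way, the dependency logic sees plain argument values and never needs to care whether an input was positional or a kwarg.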

src/processor.jl

Lines changed: 6 additions & 6 deletions

```diff
@@ -31,11 +31,11 @@ function delete_processor_callback!(name::Symbol)
 end
 
 """
-    execute!(proc::Processor, f, args...) -> Any
+    execute!(proc::Processor, f, args...; kwargs...) -> Any
 
-Executes the function `f` with arguments `args` on processor `proc`. This
-function can be overloaded by `Processor` subtypes to allow executing function
-calls differently than normal Julia.
+Executes the function `f` with arguments `args` and keyword arguments `kwargs`
+on processor `proc`. This function can be overloaded by `Processor` subtypes to
+allow executing function calls differently than normal Julia.
 """
 function execute! end
 
@@ -154,12 +154,12 @@ end
 iscompatible(proc::ThreadProc, opts, f, args...) = true
 iscompatible_func(proc::ThreadProc, opts, f) = true
 iscompatible_arg(proc::ThreadProc, opts, x) = true
-function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...))
+function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...))
     tls = get_tls()
     task = Task() do
         set_tls!(tls)
         TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id)
-        @invokelatest f(args...)
+        @invokelatest f(args...; kwargs...)
     end
     task.sticky = true
     ret = ccall(:jl_set_task_tid, Cint, (Any, Cint), task, proc.tid-1)
```
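The `execute!` change is plain kwargs forwarding: keyword arguments captured by the method are splatted into the user function inside the spawned task. A minimal sketch of that forwarding pattern, without Dagger's task-local state or thread pinning:

```julia
# Forward args and kwargs through a closure into a Task, as execute! does:
function run_on_task(f, args...; kwargs...)
    task = Task() do
        f(args...; kwargs...)  # kwargs survive the closure boundary
    end
    schedule(task)
    fetch(task)
end

@assert run_on_task(sort, [3, 1, 2]; rev=true) == [3, 2, 1]
```

The closure capture is what matters here: `kwargs...` slurped by the outer function is re-splatted unchanged at the inner call site, so the user function sees exactly the kwargs the caller wrote.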

0 commit comments