Skip to content

Commit 09f70b0

Browse files
fda-tomejpsamaroo
andauthored
DArray: Use eager API (#396)
Co-authored-by: Julian P Samaroo <[email protected]>
1 parent f7c936f commit 09f70b0

File tree

12 files changed

+129
-139
lines changed

12 files changed

+129
-139
lines changed

src/array/alloc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ end
2525

2626
function stage(ctx, a::AllocateArray)
2727
alloc(idx, sz) = a.f(idx, a.eltype, sz)
28-
thunks = [delayed(alloc)(i, size(x)) for (i, x) in enumerate(a.domainchunks)]
28+
thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(a.domainchunks)]
2929
DArray(a.eltype,a.domain, a.domainchunks, thunks)
3030
end
3131

src/array/darray.jl

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Base: ==
1+
import Base: ==, fetch
22
using Serialization
33
import Serialization: serialize, deserialize
44

@@ -78,13 +78,14 @@ domain(x::AbstractArray) = ArrayDomain([1:l for l in size(x)])
7878
abstract type ArrayOp{T, N} <: AbstractArray{T, N} end
7979
Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian()
8080

81-
compute(ctx, x::ArrayOp; options=nothing) =
82-
compute(ctx, cached_stage(ctx, x)::DArray; options=options)
8381

84-
collect(ctx::Context, x::ArrayOp; options=nothing) =
85-
collect(ctx, compute(ctx, x; options=options); options=options)
82+
collect(x::ArrayOp) = collect(fetch(x))
8683

87-
collect(x::ArrayOp; options=nothing) = collect(Context(global_context()), x; options=options)
84+
Base.fetch(x::ArrayOp) = fetch(cached_stage(Context(global_context()), x)::DArray)
85+
86+
collect(x::Computation) = collect(fetch(x))
87+
88+
Base.fetch(x::Computation) = fetch(cached_stage(Context(global_context()), x))
8889

8990
function Base.show(io::IO, ::MIME"text/plain", x::ArrayOp)
9091
write(io, string(typeof(x)))
@@ -113,7 +114,7 @@ An N-dimensional distributed array of element type T, with a concatenation funct
113114
mutable struct DArray{T,N,F} <: ArrayOp{T, N}
114115
domain::ArrayDomain{N}
115116
subdomains::AbstractArray{ArrayDomain{N}, N}
116-
chunks::AbstractArray{Union{Chunk,Thunk}, N}
117+
chunks::AbstractArray{Any, N}
117118
concat::F
118119
function DArray{T,N,F}(domain, subdomains, chunks, concat::Function) where {T, N,F}
119120
new(domain, subdomains, chunks, concat)
@@ -135,18 +136,18 @@ domainchunks(d::DArray) = d.subdomains
135136
size(x::DArray) = size(domain(x))
136137
stage(ctx, c::DArray) = c
137138

138-
function collect(ctx::Context, d::DArray; tree=false, options=nothing)
139-
a = compute(ctx, d; options=options)
139+
function collect(d::DArray; tree=false)
140+
a = fetch(d)
140141

141142
if isempty(d.chunks)
142143
return Array{eltype(d)}(undef, size(d)...)
143144
end
144145

145146
dimcatfuncs = [(x...) -> d.concat(x..., dims=i) for i in 1:ndims(d)]
146147
if tree
147-
collect(treereduce_nd(delayed.(dimcatfuncs), a.chunks))
148+
collect(fetch(treereduce_nd(map(x -> ((args...,) -> Dagger.@spawn x(args...)) , dimcatfuncs), a.chunks)))
148149
else
149-
treereduce_nd(dimcatfuncs, asyncmap(collect, a.chunks))
150+
treereduce_nd(dimcatfuncs, asyncmap(fetch, a.chunks))
150151
end
151152
end
152153

@@ -209,53 +210,33 @@ _cumsum(x::AbstractArray) = length(x) == 0 ? Int[] : cumsum(x)
209210
function lookup_parts(ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N
210211
groups = map(group_indices, subdmns.cumlength, indexes(d))
211212
sz = map(length, groups)
212-
pieces = Array{Union{Chunk,Thunk}}(undef, sz)
213+
pieces = Array{Any}(undef, sz)
213214
for i = CartesianIndices(sz)
214215
idx_and_dmn = map(getindex, groups, i.I)
215216
idx = map(x->x[1], idx_and_dmn)
216217
dmn = ArrayDomain(map(x->x[2], idx_and_dmn))
217-
pieces[i] = delayed(getindex)(ps[idx...], project(subdmns[idx...], dmn))
218+
pieces[i] = Dagger.@spawn getindex(ps[idx...], project(subdmns[idx...], dmn))
218219
end
219220
out_cumlength = map(g->_cumsum(map(x->length(x[2]), g)), groups)
220221
out_dmn = DomainBlocks(ntuple(x->1,Val(N)), out_cumlength)
221222
pieces, out_dmn
222223
end
223224

224-
225225
"""
226-
compute(ctx::Context, x::DArray; persist=true, options=nothing)
227-
228-
A `DArray` object may contain a thunk in it, in which case
229-
we first turn it into a `Thunk` and then compute it.
230-
"""
231-
function compute(ctx::Context, x::DArray; persist=true, options=nothing)
232-
thunk = thunkize(ctx, x, persist=persist)
233-
if isa(thunk, Thunk)
234-
compute(ctx, thunk; options=options)
235-
else
236-
x
237-
end
238-
end
239-
240-
"""
241-
thunkize(ctx::Context, c::DArray; persist=true)
226+
Base.fetch(c::DArray)
242227
243228
If a `DArray` tree has a `Thunk` in it, make the whole thing a big thunk.
244229
"""
245-
function thunkize(ctx::Context, c::DArray; persist=true)
230+
function Base.fetch(c::DArray)
246231
if any(istask, chunks(c))
247232
thunks = chunks(c)
248233
sz = size(thunks)
249234
dmn = domain(c)
250235
dmnchunks = domainchunks(c)
251-
if persist
252-
foreach(persist!, thunks)
253-
end
254-
Thunk(map(thunk->nothing=>thunk, thunks)...; meta=true) do results...
236+
fetch(Dagger.spawn(Options(meta=true), thunks...) do results...
255237
t = eltype(results[1])
256-
DArray(t, dmn, dmnchunks,
257-
reshape(Union{Chunk,Thunk}[results...], sz))
258-
end
238+
DArray(t, dmn, dmnchunks, reshape(Any[results...], sz))
239+
end)
259240
else
260241
c
261242
end
@@ -335,19 +316,18 @@ function stage(ctx::Context, d::Distribute)
335316
cs = map(d.domainchunks) do idx
336317
chunks = cached_stage(ctx, x[idx]).chunks
337318
shape = size(chunks)
338-
(delayed() do shape, parts...
319+
Dagger.spawn(shape, chunks...) do shape, parts...
339320
if prod(shape) == 0
340321
return Array{T}(undef, shape)
341322
end
342323
dimcatfuncs = [(x...) -> concat(x..., dims=i) for i in 1:length(shape)]
343324
ps = reshape(Any[parts...], shape)
344325
collect(treereduce_nd(dimcatfuncs, ps))
345-
end)(shape, chunks...)
326+
end
346327
end
347328
else
348-
cs = map(c -> delayed(identity)(d.data[c]), d.domainchunks)
329+
cs = map(c -> (Dagger.@spawn identity(d.data[c])), d.domainchunks)
349330
end
350-
351331
DArray(
352332
eltype(d.data),
353333
domain(d.data),
@@ -357,7 +337,7 @@ function stage(ctx::Context, d::Distribute)
357337
end
358338

359339
function distribute(x::AbstractArray, dist)
360-
compute(Distribute(dist, x))
340+
fetch(Distribute(dist, x))
361341
end
362342

363343
function distribute(x::AbstractArray{T,N}, n::NTuple{N}) where {T,N}

src/array/getindex.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ end
3434
function stage(ctx::Context, gidx::GetIndexScalar)
3535
inp = cached_stage(ctx, gidx.input)
3636
s = view(inp, ArrayDomain(gidx.idx))
37-
delayed(identity)(collect(s)[1])
37+
Dagger.@spawn identity(collect(s)[1])
3838
end
3939

4040
Base.getindex(c::ArrayOp, idx::ArrayDomain) = GetIndex(c, indexes(idx))
4141
Base.getindex(c::ArrayOp, idx...) = GetIndex(c, idx)
42-
Base.getindex(c::ArrayOp, idx::Integer...) = compute(GetIndexScalar(c, idx))
42+
Base.getindex(c::ArrayOp, idx::Integer...) = fetch(GetIndexScalar(c, idx))

src/array/map-reduce.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function stage(ctx::Context, node::Map)
1919
f = node.f
2020
for i=eachindex(domains)
2121
inps = map(x->chunks(x)[i], inputs)
22-
thunks[i] = Thunk((args...) -> map(f, args...), map(inp->nothing=>inp, inps)...)
22+
thunks[i] = Dagger.@spawn map(f, inps...)
2323
end
2424
DArray(Any, domain(primary), domainchunks(primary), thunks)
2525
end
@@ -40,16 +40,16 @@ end
4040

4141
function stage(ctx::Context, r::ReduceBlock)
4242
inp = stage(ctx, r.input)
43-
reduced_parts = map(x -> Thunk(r.op, nothing=>x; get_result=r.get_result), chunks(inp))
44-
Thunk((xs...) -> r.op_master(xs), map(part->nothing=>part, reduced_parts)...; meta=true)
43+
reduced_parts = map(x -> (Dagger.@spawn get_result=r.get_result r.op(x)), chunks(inp))
44+
r_op_master(args...,) = r.op_master(args)
45+
Dagger.@spawn meta=true r_op_master(reduced_parts...)
4546
end
4647

4748
reduceblock_async(f, x::ArrayOp; get_result=true) = ReduceBlock(f, f, x, get_result)
4849
reduceblock_async(f, g::Function, x::ArrayOp; get_result=true) = ReduceBlock(f, g, x, get_result)
4950

50-
reduceblock(f, x::ArrayOp) = compute(reduceblock_async(f, x))
51-
reduceblock(f, g::Function, x::ArrayOp) =
52-
compute(reduceblock_async(f, g, x))
51+
reduceblock(f, x::ArrayOp) = fetch(reduceblock_async(f, x))
52+
reduceblock(f, g::Function, x::ArrayOp) = fetch(reduceblock_async(f, g, x))
5353

5454
reduce_async(f::Function, x::ArrayOp) = reduceblock_async(xs->reduce(f,xs), xs->reduce(f,xs), x)
5555

@@ -62,7 +62,7 @@ prod(f::Function, x::ArrayOp) = reduceblock(a->prod(f, a), prod, x)
6262

6363
mean(x::ArrayOp) = reduceblock(mean, mean, x)
6464

65-
mapreduce(f::Function, g::Function, x::ArrayOp) = reduce(g, map(f, x))
65+
mapreduce(f::Function, g::Function, x::ArrayOp) = reduce(g, map(f, x)) #think about fetching
6666

6767
function mapreducebykey_seq(f, op, itr, dict=Dict())
6868
for x in itr
@@ -115,7 +115,7 @@ end
115115

116116
function reduce(f::Function, x::ArrayOp; dims = nothing)
117117
if dims === nothing
118-
return compute(reduce_async(f,x))
118+
return fetch(reduce_async(f,x))
119119
elseif dims isa Int
120120
dims = (dims,)
121121
end
@@ -126,10 +126,10 @@ function stage(ctx::Context, r::Reducedim)
126126
inp = cached_stage(ctx, r.input)
127127
thunks = let op = r.op, dims=r.dims
128128
# do reducedim on each block
129-
tmp = map(p->Thunk(b->reduce(op,b,dims=dims), nothing=>p), chunks(inp))
129+
tmp = map(p->Dagger.spawn(b->reduce(op,b,dims=dims), p), chunks(inp))
130130
# combine the results in tree fashion
131131
treereducedim(tmp, r.dims) do x,y
132-
Thunk(op, nothing=>x, nothing=>y)
132+
Dagger.@spawn op(x,y)
133133
end
134134
end
135135
c = domainchunks(inp)

src/array/matrix.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ function size(x::Transpose)
1717
end
1818

1919
transpose(x::ArrayOp) = Transpose(transpose, x)
20-
transpose(x::Union{Chunk, Thunk}) = Thunk(transpose, nothing=>x)
20+
transpose(x::Union{Chunk, EagerThunk}) = @spawn transpose(x)
2121

2222
adjoint(x::ArrayOp) = Transpose(adjoint, x)
23-
adjoint(x::Union{Chunk, Thunk}) = Thunk(adjoint, nothing=>x)
23+
adjoint(x::Union{Chunk, EagerThunk}) = @spawn adjoint(x)
2424

2525
function adjoint(x::ArrayDomain{2})
2626
d = indexes(x)
@@ -32,7 +32,7 @@ function adjoint(x::ArrayDomain{1})
3232
end
3333

3434
function _ctranspose(x::AbstractArray)
35-
Any[delayed(adjoint)(x[j,i]) for i=1:size(x,2), j=1:size(x,1)]
35+
Any[Dagger.@spawn adjoint(x[j,i]) for i=1:size(x,2), j=1:size(x,1)]
3636
end
3737

3838
function stage(ctx::Context, node::Transpose)
@@ -91,8 +91,12 @@ function (+)(a::ArrayDomain, b::ArrayDomain)
9191
a
9292
end
9393

94-
(*)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(*, nothing=>a, nothing=>b)
95-
(+)(a::Union{Chunk, Thunk}, b::Union{Chunk, Thunk}) = Thunk(+, nothing=>a, nothing=>b)
94+
struct BinaryComputeOp{F} end
95+
BinaryComputeOp{F}(x::Union{Chunk,EagerThunk}, y::Union{Chunk,EagerThunk}) where F = @spawn F(x, y)
96+
BinaryComputeOp{F}(x, y) where F = F(x, y)
97+
98+
const AddComputeOp = BinaryComputeOp{+}
99+
const MulComputeOp = BinaryComputeOp{*}
96100

97101
# we define our own matmat and matvec multiply
98102
# for computing the new domains and thunks.
@@ -101,7 +105,7 @@ function _mul(a::Matrix, b::Matrix; T=eltype(a))
101105
n = size(a, 2)
102106
for i=1:size(a,1)
103107
for j=1:size(b, 2)
104-
c[i,j] = treereduce(+, map(*, reshape(a[i,:], (n,)), b[:, j]))
108+
c[i,j] = treereduce(AddComputeOp, map(MulComputeOp, reshape(a[i,:], (n,)), b[:, j]))
105109
end
106110
end
107111
c
@@ -111,14 +115,14 @@ function _mul(a::Matrix, b::Vector; T=eltype(b))
111115
c = Array{T}(undef, size(a,1))
112116
n = size(a,2)
113117
for i=1:size(a,1)
114-
c[i] = treereduce(+, map(*, reshape(a[i, :], (n,)), b))
118+
c[i] = treereduce(AddComputeOp, map(MulComputeOp, reshape(a[i, :], (n,)), b))
115119
end
116120
c
117121
end
118122

119123
function _mul(a::Vector, b::Vector; T=eltype(b))
120124
@assert length(b) == 1
121-
[x * b[1] for x in a]
125+
[MulComputeOp(x, b[1]) for x in a]
122126
end
123127

124128
function promote_distribution(ctx::Context, m::MatMul, a,b)
@@ -176,7 +180,7 @@ function stage(ctx::Context, mul::MatMul)
176180
a, b = stage_operands(ctx, mul, mul.a, mul.b)
177181
d = domain(a)*domain(b)
178182
DArray(Any, d, domainchunks(a)*domainchunks(b),
179-
_mul(chunks(a), chunks(b); T=Thunk))
183+
_mul(chunks(a), chunks(b); T=Any))
180184
end
181185

182186
Base.power_by_squaring(x::DArray, i::Int) = foldl(*, ntuple(idx->x, i))
@@ -211,7 +215,7 @@ end
211215
function _scale(l, r)
212216
res = similar(r, Any)
213217
for i=1:length(l)
214-
res[i,:] = map(x->Thunk((a,b) -> Diagonal(a)*b, nothing=>l[i], nothing=>x), r[i,:])
218+
res[i,:] = map(x->Dagger.spawn((a,b) -> Diagonal(a)*b, l[i], x), r[i,:])
215219
end
216220
res
217221
end

src/array/operators.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,7 @@ function stage(ctx::Context, node::BCast)
9090
end
9191
blcks = DomainBlocks(map(_->1, size(node)), cumlengths)
9292

93-
thunks = broadcast(delayed((args...)->broadcast(bc.f, args...); ),
94-
args2...)
93+
thunks = broadcast((args3...)->Dagger.spawn((args...)->broadcast(bc.f, args...), args3...), args2...)
9594
DArray(eltype(node), domain(node), blcks, thunks)
9695
end
9796

@@ -107,7 +106,7 @@ Base.@deprecate mappart(args...) mapchunk(args...)
107106
function stage(ctx::Context, node::MapChunk)
108107
inputs = map(x->cached_stage(ctx, x), node.input)
109108
thunks = map(map(chunks, inputs)...) do ps...
110-
Thunk(node.f, map(p->nothing=>p, ps)...)
109+
Dagger.spawn(node.f, map(p->nothing=>p, ps)...)
111110
end
112111

113112
DArray(Any, domain(inputs[1]), domainchunks(inputs[1]), thunks)

src/array/setindex.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,14 @@ function stage(ctx::Context, sidx::SetIndex)
2828

2929
groups = map(group_indices, subdmns.cumlength, indexes(d))
3030
sz = map(length, groups)
31-
pieces = Array{Union{Chunk, Thunk}}(undef, sz)
31+
pieces = Array{Any}(undef, sz)
3232
for i = CartesianIndices(sz)
3333
idx_and_dmn = map(getindex, groups, i.I)
3434
idx = map(x->x[1], idx_and_dmn)
3535
local_dmn = ArrayDomain(map(x->x[2], idx_and_dmn))
3636
s = subdmns[idx...]
3737
part_to_set = sidx.val
38-
ps[idx...] = Thunk(nothing=>ps[idx...]) do p
38+
ps[idx...] = Dagger.spawn(ps[idx...]) do p
3939
q = copy(p)
4040
q[indexes(project(s, local_dmn))...] .= part_to_set
4141
q

0 commit comments

Comments
 (0)