Skip to content

Commit 65bae59

Browse files
Felipe de Alcântara Toméjpsamaroo
authored andcommitted
DArray: Operations return DArrays, local partition
To enable MPI support in the DArray (which is best implemented as each MPI rank holding only a single local partition), this commit splits AbstractBlock further by multi-partition or single-partition storage schemes, where `Blocks <: AbstractMultiBlocks`, and a future `MPIBlocks <: AbstractSingleBlocks`. Additionally, a DArray ctor is added for when only a single subdomain and chunk/thunk is provided. For easier post-hoc repartitioning of DArrays, we now store the original user-provided partitioning scheme within the DArray, and also add it as a type parameter. This also assists future MPI integration by allowing for operations to dispatch on the partitioning scheme. Finally, this commit also adjusts all DArray operations to return DArrays, so that no lazy operators like MatMul or Map are returned to the user. While it would be nice to be able to work with these operators directly, the array ecosystem has generally settled on propagating arrays of the same or similar types as the inputs to many operations. The operators themselves are still present behind a single materializing call, so it should be possible to expose them again in the future if optimization opportunities become feasible.
1 parent f63514a commit 65bae59

File tree

8 files changed

+197
-129
lines changed

8 files changed

+197
-129
lines changed

src/array/alloc.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mutable struct AllocateArray{T,N} <: ArrayOp{T,N}
88
f::Function
99
domain::ArrayDomain{N}
1010
domainchunks
11+
partitioning::AbstractBlocks
1112
end
1213
size(a::AllocateArray) = size(a.domain)
1314

@@ -18,15 +19,15 @@ function _cumlength(len, step)
1819
cumsum(extra > 0 ? vcat(ps, extra) : ps)
1920
end
2021

21-
function partition(p::Blocks, dom::ArrayDomain)
22+
function partition(p::AbstractBlocks, dom::ArrayDomain)
2223
DomainBlocks(map(first, indexes(dom)),
2324
map(_cumlength, map(length, indexes(dom)), p.blocksize))
2425
end
2526

2627
function stage(ctx, a::AllocateArray)
2728
alloc(idx, sz) = a.f(idx, a.eltype, sz)
2829
thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(a.domainchunks)]
29-
DArray(a.eltype,a.domain, a.domainchunks, thunks)
30+
return DArray(a.eltype, a.domain, a.domainchunks, thunks, a.partitioning)
3031
end
3132

3233
function Base.rand(p::Blocks, eltype::Type, dims)
@@ -35,7 +36,8 @@ function Base.rand(p::Blocks, eltype::Type, dims)
3536
rand(MersenneTwister(s+idx), x...)
3637
end
3738
d = ArrayDomain(map(x->1:x, dims))
38-
AllocateArray(eltype, f, d, partition(p, d))
39+
a = AllocateArray(eltype, f, d, partition(p, d), p)
40+
return _to_darray(a)
3941
end
4042

4143
Base.rand(p::Blocks, t::Type, dims::Integer...) = rand(p, t, dims)
@@ -48,21 +50,24 @@ function Base.randn(p::Blocks, dims)
4850
randn(MersenneTwister(s+idx), x...)
4951
end
5052
d = ArrayDomain(map(x->1:x, dims))
51-
AllocateArray(Float64, f, d, partition(p, d))
53+
a = AllocateArray(Float64, f, d, partition(p, d), p)
54+
return _to_darray(a)
5255
end
5356
Base.randn(p::Blocks, dims::Integer...) = randn(p, dims)
5457

5558
function Base.ones(p::Blocks, eltype::Type, dims)
5659
d = ArrayDomain(map(x->1:x, dims))
57-
AllocateArray(eltype, (_, x...) -> ones(x...), d, partition(p, d))
60+
a = AllocateArray(eltype, (_, x...) -> ones(x...), d, partition(p, d), p)
61+
return _to_darray(a)
5862
end
5963
Base.ones(p::Blocks, t::Type, dims::Integer...) = ones(p, t, dims)
6064
Base.ones(p::Blocks, dims::Integer...) = ones(p, Float64, dims)
6165
Base.ones(p::Blocks, dims::Tuple) = ones(p, Float64, dims)
6266

6367
function Base.zeros(p::Blocks, eltype::Type, dims)
6468
d = ArrayDomain(map(x->1:x, dims))
65-
AllocateArray(eltype, (_, x...) -> zeros(x...), d, partition(p, d))
69+
a = AllocateArray(eltype, (_, x...) -> zeros(x...), d, partition(p, d), p)
70+
return _to_darray(a)
6671
end
6772
Base.zeros(p::Blocks, t::Type, dims::Integer...) = zeros(p, t, dims)
6873
Base.zeros(p::Blocks, dims::Integer...) = zeros(p, Float64, dims)
@@ -73,7 +78,7 @@ function Base.zero(x::DArray{T,N}) where {T,N}
7378
sd = first(x.subdomains)
7479
part_size = ntuple(i->sd.indexes[i].stop, N)
7580
a = zeros(Blocks(part_size...), T, dims)
76-
return cached_stage(Context(global_context()), a)
81+
return _to_darray(a)
7782
end
7883

7984
function sprand(p::Blocks, m::Integer, n::Integer, sparsity::Real)
@@ -82,13 +87,15 @@ function sprand(p::Blocks, m::Integer, n::Integer, sparsity::Real)
8287
sprand(MersenneTwister(s+idx), sz...,sparsity)
8388
end
8489
d = ArrayDomain((1:m, 1:n))
85-
AllocateArray(Float64, f, d, partition(p, d))
90+
a = AllocateArray(Float64, f, d, partition(p, d), p)
91+
return _to_darray(a)
8692
end
8793

8894
function sprand(p::Blocks, n::Integer, sparsity::Real)
8995
s = rand(UInt)
9096
f = function (idx,t,sz)
9197
sprand(MersenneTwister(s+idx), sz...,sparsity)
9298
end
93-
AllocateArray(Float64, f, d, partition(p, ArrayDomain((1:n,))))
99+
a = AllocateArray(Float64, f, d, partition(p, ArrayDomain((1:n,))), p)
100+
return _to_darray(a)
94101
end

src/array/darray.jl

Lines changed: 83 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian()
8181

8282
collect(x::ArrayOp) = collect(fetch(x))
8383

84-
Base.fetch(x::ArrayOp) = fetch(cached_stage(Context(global_context()), x)::DArray)
84+
_to_darray(x::ArrayOp) = cached_stage(Context(global_context()), x)::DArray
85+
Base.fetch(x::ArrayOp) = fetch(_to_darray(x))
8586

8687
collect(x::Computation) = collect(fetch(x))
8788

@@ -97,6 +98,27 @@ function Base.show(io::IO, x::ArrayOp)
9798
show(io, m, x)
9899
end
99100

101+
export BlockPartition, Blocks
102+
103+
abstract type AbstractBlocks{N} end
104+
105+
abstract type AbstractMultiBlocks{N}<:AbstractBlocks{N} end
106+
107+
abstract type AbstractSingleBlocks{N}<:AbstractBlocks{N} end
108+
109+
struct Blocks{N} <: AbstractMultiBlocks{N}
110+
blocksize::NTuple{N, Int}
111+
end
112+
113+
"""
114+
Blocks(xs...)
115+
116+
Indicates the size of an array operation, specified as `xs`, whose length
117+
indicates the number of dimensions in the resulting array.
118+
"""
119+
Blocks(xs::Int...) = Blocks(xs)
120+
121+
100122
"""
101123
DArray{T,N,F}(domain, subdomains, chunks, concat)
102124
DArray(T, domain, subdomains, chunks, [concat=cat])
@@ -111,23 +133,35 @@ An N-dimensional distributed array of element type T, with a concatenation funct
111133
- `concat::F`: a function of type `F`. `concat(x, y; dims=d)` takes two chunks `x` and `y`
112134
and concatenates them along dimension `d`. `cat` is used by default.
113135
"""
114-
mutable struct DArray{T,N,F} <: ArrayOp{T, N}
136+
mutable struct DArray{T,N,B<:AbstractBlocks{N},F} <: ArrayOp{T, N}
115137
domain::ArrayDomain{N}
116138
subdomains::AbstractArray{ArrayDomain{N}, N}
117139
chunks::AbstractArray{Any, N}
140+
partitioning::B
118141
concat::F
119-
function DArray{T,N,F}(domain, subdomains, chunks, concat::Function) where {T,N,F}
120-
new{T,N,F}(domain, subdomains, chunks, concat)
142+
function DArray{T,N,B,F}(domain, subdomains, chunks, partitioning::B, concat::Function) where {T,N,B,F}
143+
new{T,N,B,F}(domain, subdomains, chunks, partitioning, concat)
121144
end
122145
end
123146

124147
# mainly for backwards-compatibility
125-
DArray{T, N}(domain, subdomains, chunks) where {T,N} = DArray(T, domain, subdomains, chunks)
148+
DArray{T, N}(domain, subdomains, chunks, partitioning, concat=cat) where {T,N} =
149+
DArray(T, domain, subdomains, chunks, partitioning, concat)
150+
151+
function DArray(T, domain::ArrayDomain{N},
152+
subdomains::AbstractArray{ArrayDomain{N}, N},
153+
chunks::AbstractArray{<:Any, N}, partitioning::B, concat=cat) where {N,B<:AbstractMultiBlocks{N}}
154+
DArray{T,N,B,typeof(concat)}(domain, subdomains, chunks, partitioning, concat)
155+
end
126156

127157
function DArray(T, domain::ArrayDomain{N},
128-
subdomains::AbstractArray{ArrayDomain{N}, N},
129-
chunks::AbstractArray{<:Any, N}, concat=cat) where N
130-
DArray{T, N, typeof(concat)}(domain, subdomains, chunks, concat)
158+
subdomains::ArrayDomain{N},
159+
chunks::Any, partitioning::B, concat=cat) where {N,B<:AbstractSingleBlocks{N}}
160+
_subdomains = Array{ArrayDomain{N}, N}(undef, ntuple(i->1, N)...)
161+
_subdomains[1] = subdomains
162+
_chunks = Array{Any, N}(undef, ntuple(i->1, N)...)
163+
_chunks[1] = chunks
164+
DArray{T,N,B,typeof(concat)}(domain, _subdomains, _chunks, partitioning, concat)
131165
end
132166

133167
domain(d::DArray) = d.domain
@@ -136,9 +170,8 @@ domainchunks(d::DArray) = d.subdomains
136170
size(x::DArray) = size(domain(x))
137171
stage(ctx, c::DArray) = c
138172

139-
function collect(d::DArray; tree=false)
173+
function Base.collect(d::DArray; tree=false)
140174
a = fetch(d)
141-
142175
if isempty(d.chunks)
143176
return Array{eltype(d)}(undef, size(d)...)
144177
end
@@ -163,18 +196,18 @@ function Base.isequal(x::ArrayOp, y::ArrayOp)
163196
x === y
164197
end
165198

166-
function Base.similar(x::DArray{T,N,F}) where {T,N,F}
199+
function Base.similar(x::DArray{T,N}) where {T,N}
167200
alloc(idx, sz) = Array{T,N}(undef, sz)
168201
thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(x.subdomains)]
169-
return DArray{T,N,F}(x.domain, x.subdomains, thunks, x.concat)
202+
return DArray(T, x.domain, x.subdomains, thunks, x.partitioning, x.concat)
170203
end
171204

172-
Base.copy(x::DArray{T,N,F}) where {T,N,F} =
173-
cached_stage(Context(global_context()), map(identity, x))::DArray{T,N,F}
205+
Base.copy(x::DArray{T,N,B,F}) where {T,N,B,F} =
206+
map(identity, x)::DArray{T,N,B,F}
174207

175208
# Because OrdinaryDiffEq uses `Base.promote_op(/, ::DArray, ::Real)`
176-
Base.:(/)(x::DArray{T,N,F}, y::U) where {T<:Real,U<:Real,N,F} =
177-
(x ./ y)::DArray{Base.promote_op(/, T, U),N,F}
209+
Base.:(/)(x::DArray{T,N,B,F}, y::U) where {T<:Real,U<:Real,N,B,F} =
210+
(x ./ y)::DArray{Base.promote_op(/, T, U),N,B,F}
178211

179212
"""
180213
view(c::DArray, d)
@@ -184,7 +217,7 @@ A `view` of a `DArray` chunk returns a `DArray` of `Thunk`s.
184217
function Base.view(c::DArray, d)
185218
subchunks, subdomains = lookup_parts(chunks(c), domainchunks(c), d)
186219
d1 = alignfirst(d)
187-
DArray(eltype(c), d1, subdomains, subchunks)
220+
DArray(eltype(c), d1, subdomains, subchunks, c.partitioning, c.concat)
188221
end
189222

190223
function group_indices(cumlength, idxs,at=1, acc=Any[])
@@ -246,12 +279,13 @@ function Base.fetch(c::DArray{T}) where T
246279
sz = size(thunks)
247280
dmn = domain(c)
248281
dmnchunks = domainchunks(c)
249-
fetch(Dagger.spawn(Options(meta=true), thunks...) do results...
282+
return fetch(Dagger.spawn(Options(meta=true), thunks...) do results...
250283
t = eltype(fetch(results[1]))
251-
DArray(t, dmn, dmnchunks, reshape(Any[results...], sz))
284+
DArray(t, dmn, dmnchunks, reshape(Any[results...], sz),
285+
c.partitioning, c.concat)
252286
end)
253287
else
254-
c
288+
return c
255289
end
256290
end
257291

@@ -290,31 +324,31 @@ Base.@deprecate_binding ComputedArray DArray
290324

291325
export Distribute, distribute
292326

293-
struct Distribute{T, N} <: ArrayOp{T, N}
327+
struct Distribute{T,N,B<:AbstractBlocks} <: ArrayOp{T, N}
294328
domainchunks
329+
partitioning::B
295330
data::AbstractArray{T,N}
296331
end
297332

298333
size(x::Distribute) = size(domain(x.data))
299334

300-
export BlockPartition, Blocks
301-
302-
"""
303-
Blocks(xs...)
304-
305-
Indicates the size of an array operation, specified as `xs`, whose length
306-
indicates the number of dimensions in the resulting array.
307-
"""
308-
struct Blocks{N}
309-
blocksize::NTuple{N, Int}
310-
end
311-
Blocks(xs::Int...) = Blocks(xs)
312-
313335
Base.@deprecate BlockPartition Blocks
314336

315337

316338
Distribute(p::Blocks, data::AbstractArray) =
317-
Distribute(partition(p, domain(data)), data)
339+
Distribute(partition(p, domain(data)), p, data)
340+
341+
function Distribute(domainchunks::DomainBlocks{N}, data::AbstractArray{T,N}) where {T,N}
342+
p = Blocks(ntuple(i->first(domainchunks.cumlength[i]), N))
343+
Distribute(domainchunks, p, data)
344+
end
345+
346+
function Distribute(data::AbstractArray{T,N}) where {T,N}
347+
nprocs = sum(w->length(Dagger.get_processors(OSProc(w))),
348+
Distributed.procs())
349+
p = Blocks(ntuple(i->max(cld(size(data, i), nprocs), 1), N))
350+
return Distribute(partition(p, domain(data)), p, data)
351+
end
318352

319353
function stage(ctx::Context, d::Distribute)
320354
if isa(d.data, ArrayOp)
@@ -329,6 +363,8 @@ function stage(ctx::Context, d::Distribute)
329363
cs = map(d.domainchunks) do idx
330364
chunks = cached_stage(ctx, x[idx]).chunks
331365
shape = size(chunks)
366+
# TODO: fix hashing
367+
#hash = uhash(idx, Base.hash(Distribute, Base.hash(d.data)))
332368
Dagger.spawn(shape, chunks...) do shape, parts...
333369
if prod(shape) == 0
334370
return Array{T}(undef, shape)
@@ -339,18 +375,21 @@ function stage(ctx::Context, d::Distribute)
339375
end
340376
end
341377
else
342-
cs = map(c -> (Dagger.@spawn identity(d.data[c])), d.domainchunks)
378+
cs = map(d.domainchunks) do c
379+
# TODO: fix hashing
380+
#hash = uhash(c, Base.hash(Distribute, Base.hash(d.data)))
381+
Dagger.@spawn identity(d.data[c])
382+
end
343383
end
344-
DArray(
345-
eltype(d.data),
346-
domain(d.data),
347-
d.domainchunks,
348-
cs
349-
)
384+
return DArray(eltype(d.data),
385+
domain(d.data),
386+
d.domainchunks,
387+
cs,
388+
d.partitioning)
350389
end
351390

352-
function distribute(x::AbstractArray, dist)
353-
fetch(Distribute(dist, x))
391+
function distribute(x::AbstractArray, dist::Blocks)
392+
_to_darray(Distribute(dist, x))
354393
end
355394

356395
function distribute(x::AbstractArray{T,N}, n::NTuple{N}) where {T,N}

src/array/getindex.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ function stage(ctx::Context, gidx::GetIndexScalar)
3737
Dagger.@spawn identity(collect(s)[1])
3838
end
3939

40-
Base.getindex(c::ArrayOp, idx::ArrayDomain) = GetIndex(c, indexes(idx))
41-
Base.getindex(c::ArrayOp, idx...) = GetIndex(c, idx)
40+
Base.getindex(c::ArrayOp, idx::ArrayDomain) = _to_darray(GetIndex(c, indexes(idx)))
41+
Base.getindex(c::ArrayOp, idx...) = _to_darray(GetIndex(c, idx))
4242
Base.getindex(c::ArrayOp, idx::Integer...) = fetch(GetIndexScalar(c, idx))

0 commit comments

Comments
 (0)