@@ -81,7 +81,8 @@ Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian()
8181
8282collect (x:: ArrayOp ) = collect (fetch (x))
8383
84- Base. fetch (x:: ArrayOp ) = fetch (cached_stage (Context (global_context ()), x):: DArray )
84+ _to_darray (x:: ArrayOp ) = cached_stage (Context (global_context ()), x):: DArray
85+ Base. fetch (x:: ArrayOp ) = fetch (_to_darray (x))
8586
8687collect (x:: Computation ) = collect (fetch (x))
8788
@@ -97,6 +98,27 @@ function Base.show(io::IO, x::ArrayOp)
9798 show (io, m, x)
9899end
99100
101+ export BlockPartition, Blocks
102+
103+ abstract type AbstractBlocks{N} end
104+
105+ abstract type AbstractMultiBlocks{N}<: AbstractBlocks{N} end
106+
107+ abstract type AbstractSingleBlocks{N}<: AbstractBlocks{N} end
108+
109+ struct Blocks{N} <: AbstractMultiBlocks{N}
110+ blocksize:: NTuple{N, Int}
111+ end
112+
113+ """
114+ Blocks(xs...)
115+
116+ Indicates the size of an array operation, specified as `xs`, whose length
117+ indicates the number of dimensions in the resulting array.
118+ """
119+ Blocks (xs:: Int... ) = Blocks (xs)
120+
121+
100122"""
101123 DArray{T,N,F}(domain, subdomains, chunks, concat)
102124 DArray(T, domain, subdomains, chunks, [concat=cat])
@@ -111,23 +133,35 @@ An N-dimensional distributed array of element type T, with a concatenation funct
111133- `concat::F`: a function of type `F`. `concat(x, y; dims=d)` takes two chunks `x` and `y`
112134 and concatenates them along dimension `d`. `cat` is used by default.
113135"""
114- mutable struct DArray{T,N,F} <: ArrayOp{T, N}
136+ mutable struct DArray{T,N,B <: AbstractBlocks{N} , F} <: ArrayOp{T, N}
115137 domain:: ArrayDomain{N}
116138 subdomains:: AbstractArray{ArrayDomain{N}, N}
117139 chunks:: AbstractArray{Any, N}
140+ partitioning:: B
118141 concat:: F
119- function DArray {T,N,F} (domain, subdomains, chunks, concat:: Function ) where {T,N,F}
120- new {T,N,F} (domain, subdomains, chunks, concat)
142+ function DArray {T,N,B, F} (domain, subdomains, chunks, partitioning :: B , concat:: Function ) where {T,N,B ,F}
143+ new {T,N,B, F} (domain, subdomains, chunks, partitioning , concat)
121144 end
122145end
123146
124147# mainly for backwards-compatibility
125- DArray {T, N} (domain, subdomains, chunks) where {T,N} = DArray (T, domain, subdomains, chunks)
148+ DArray {T, N} (domain, subdomains, chunks, partitioning, concat= cat) where {T,N} =
149+ DArray (T, domain, subdomains, chunks, partitioning, concat)
150+
151+ function DArray (T, domain:: ArrayDomain{N} ,
152+ subdomains:: AbstractArray{ArrayDomain{N}, N} ,
153+ chunks:: AbstractArray{<:Any, N} , partitioning:: B , concat= cat) where {N,B<: AbstractMultiBlocks{N} }
154+ DArray {T,N,B,typeof(concat)} (domain, subdomains, chunks, partitioning, concat)
155+ end
126156
127157function DArray (T, domain:: ArrayDomain{N} ,
128- subdomains:: AbstractArray{ArrayDomain{N}, N} ,
129- chunks:: AbstractArray{<:Any, N} , concat= cat) where N
130- DArray {T, N, typeof(concat)} (domain, subdomains, chunks, concat)
158+ subdomains:: ArrayDomain{N} ,
159+ chunks:: Any , partitioning:: B , concat= cat) where {N,B<: AbstractSingleBlocks{N} }
160+ _subdomains = Array {ArrayDomain{N}, N} (undef, ntuple (i-> 1 , N)... )
161+ _subdomains[1 ] = subdomains
162+ _chunks = Array {Any, N} (undef, ntuple (i-> 1 , N)... )
163+ _chunks[1 ] = chunks
164+ DArray {T,N,B,typeof(concat)} (domain, _subdomains, _chunks, partitioning, concat)
131165end
132166
133167domain (d:: DArray ) = d. domain
@@ -136,9 +170,8 @@ domainchunks(d::DArray) = d.subdomains
136170size (x:: DArray ) = size (domain (x))
137171stage (ctx, c:: DArray ) = c
138172
139- function collect (d:: DArray ; tree= false )
173+ function Base . collect (d:: DArray ; tree= false )
140174 a = fetch (d)
141-
142175 if isempty (d. chunks)
143176 return Array {eltype(d)} (undef, size (d)... )
144177 end
@@ -163,18 +196,18 @@ function Base.isequal(x::ArrayOp, y::ArrayOp)
163196 x === y
164197end
165198
166- function Base. similar (x:: DArray{T,N,F } ) where {T,N,F }
199+ function Base. similar (x:: DArray{T,N} ) where {T,N}
167200 alloc (idx, sz) = Array {T,N} (undef, sz)
168201 thunks = [Dagger. @spawn alloc (i, size (x)) for (i, x) in enumerate (x. subdomains)]
169- return DArray {T,N,F} ( x. domain, x. subdomains, thunks, x. concat)
202+ return DArray (T, x. domain, x. subdomains, thunks, x . partitioning , x. concat)
170203end
171204
172- Base. copy (x:: DArray{T,N,F} ) where {T,N,F} =
173- cached_stage ( Context ( global_context ()), map (identity, x)) :: DArray{T,N,F}
205+ Base. copy (x:: DArray{T,N,B, F} ) where {T,N,B ,F} =
206+ map (identity, x):: DArray{T,N,B ,F}
174207
175208# Because OrdinaryDiffEq uses `Base.promote_op(/, ::DArray, ::Real)`
176- Base.:(/ )(x:: DArray{T,N,F} , y:: U ) where {T<: Real ,U<: Real ,N,F} =
177- (x ./ y):: DArray{Base.promote_op(/, T, U),N,F}
209+ Base.:(/ )(x:: DArray{T,N,B, F} , y:: U ) where {T<: Real ,U<: Real ,N,B ,F} =
210+ (x ./ y):: DArray{Base.promote_op(/, T, U),N,B, F}
178211
179212"""
180213 view(c::DArray, d)
@@ -184,7 +217,7 @@ A `view` of a `DArray` chunk returns a `DArray` of `Thunk`s.
184217function Base. view (c:: DArray , d)
185218 subchunks, subdomains = lookup_parts (chunks (c), domainchunks (c), d)
186219 d1 = alignfirst (d)
187- DArray (eltype (c), d1, subdomains, subchunks)
220+ DArray (eltype (c), d1, subdomains, subchunks, c . partitioning, c . concat )
188221end
189222
190223function group_indices (cumlength, idxs,at= 1 , acc= Any[])
@@ -246,12 +279,13 @@ function Base.fetch(c::DArray{T}) where T
246279 sz = size (thunks)
247280 dmn = domain (c)
248281 dmnchunks = domainchunks (c)
249- fetch (Dagger. spawn (Options (meta= true ), thunks... ) do results...
282+ return fetch (Dagger. spawn (Options (meta= true ), thunks... ) do results...
250283 t = eltype (fetch (results[1 ]))
251- DArray (t, dmn, dmnchunks, reshape (Any[results... ], sz))
284+ DArray (t, dmn, dmnchunks, reshape (Any[results... ], sz),
285+ c. partitioning, c. concat)
252286 end )
253287 else
254- c
288+ return c
255289 end
256290end
257291
@@ -290,31 +324,31 @@ Base.@deprecate_binding ComputedArray DArray
290324
291325export Distribute, distribute
292326
293- struct Distribute{T, N } <: ArrayOp{T, N}
327+ struct Distribute{T,N,B <: AbstractBlocks } <: ArrayOp{T, N}
294328 domainchunks
329+ partitioning:: B
295330 data:: AbstractArray{T,N}
296331end
297332
298333size (x:: Distribute ) = size (domain (x. data))
299334
300- export BlockPartition, Blocks
301-
302- """
303- Blocks(xs...)
304-
305- Indicates the size of an array operation, specified as `xs`, whose length
306- indicates the number of dimensions in the resulting array.
307- """
308- struct Blocks{N}
309- blocksize:: NTuple{N, Int}
310- end
311- Blocks (xs:: Int... ) = Blocks (xs)
312-
313335Base. @deprecate BlockPartition Blocks
314336
315337
316338Distribute (p:: Blocks , data:: AbstractArray ) =
317- Distribute (partition (p, domain (data)), data)
339+ Distribute (partition (p, domain (data)), p, data)
340+
341+ function Distribute (domainchunks:: DomainBlocks{N} , data:: AbstractArray{T,N} ) where {T,N}
342+ p = Blocks (ntuple (i-> first (domainchunks. cumlength[i]), N))
343+ Distribute (domainchunks, p, data)
344+ end
345+
346+ function Distribute (data:: AbstractArray{T,N} ) where {T,N}
347+ nprocs = sum (w-> length (Dagger. get_processors (OSProc (w))),
348+ Distributed. procs ())
349+ p = Blocks (ntuple (i-> max (cld (size (data, i), nprocs), 1 ), N))
350+ return Distribute (partition (p, domain (data)), p, data)
351+ end
318352
319353function stage (ctx:: Context , d:: Distribute )
320354 if isa (d. data, ArrayOp)
@@ -329,6 +363,8 @@ function stage(ctx::Context, d::Distribute)
329363 cs = map (d. domainchunks) do idx
330364 chunks = cached_stage (ctx, x[idx]). chunks
331365 shape = size (chunks)
366+ # TODO : fix hashing
367+ # hash = uhash(idx, Base.hash(Distribute, Base.hash(d.data)))
332368 Dagger. spawn (shape, chunks... ) do shape, parts...
333369 if prod (shape) == 0
334370 return Array {T} (undef, shape)
@@ -339,18 +375,21 @@ function stage(ctx::Context, d::Distribute)
339375 end
340376 end
341377 else
342- cs = map (c -> (Dagger. @spawn identity (d. data[c])), d. domainchunks)
378+ cs = map (d. domainchunks) do c
379+ # TODO : fix hashing
380+ # hash = uhash(c, Base.hash(Distribute, Base.hash(d.data)))
381+ Dagger. @spawn identity (d. data[c])
382+ end
343383 end
344- DArray (
345- eltype (d. data),
346- domain (d. data),
347- d. domainchunks,
348- cs
349- )
384+ return DArray (eltype (d. data),
385+ domain (d. data),
386+ d. domainchunks,
387+ cs,
388+ d. partitioning)
350389end
351390
352- function distribute (x:: AbstractArray , dist)
353- fetch (Distribute (dist, x))
391+ function distribute (x:: AbstractArray , dist:: Blocks )
392+ _to_darray (Distribute (dist, x))
354393end
355394
356395function distribute (x:: AbstractArray{T,N} , n:: NTuple{N} ) where {T,N}
0 commit comments