1- import Base: ==
1+ import Base: == , fetch
22using Serialization
33import Serialization: serialize, deserialize
44
@@ -78,13 +78,14 @@ domain(x::AbstractArray) = ArrayDomain([1:l for l in size(x)])
7878abstract type ArrayOp{T, N} <: AbstractArray{T, N} end
7979Base. IndexStyle (:: Type{<:ArrayOp} ) = IndexCartesian ()
8080
81- compute (ctx, x:: ArrayOp ; options= nothing ) =
82- compute (ctx, cached_stage (ctx, x):: DArray ; options= options)
8381
84- collect (ctx:: Context , x:: ArrayOp ; options= nothing ) =
85- collect (ctx, compute (ctx, x; options= options); options= options)
82+ collect (x:: ArrayOp ) = collect (fetch (x))
8683
87- collect (x:: ArrayOp ; options= nothing ) = collect (Context (global_context ()), x; options= options)
84+ Base. fetch (x:: ArrayOp ) = fetch (cached_stage (Context (global_context ()), x):: DArray )
85+
86+ collect (x:: Computation ) = collect (fetch (x))
87+
88+ Base. fetch (x:: Computation ) = fetch (cached_stage (Context (global_context ()), x))
8889
8990function Base. show (io:: IO , :: MIME"text/plain" , x:: ArrayOp )
9091 write (io, string (typeof (x)))
@@ -113,7 +114,7 @@ An N-dimensional distributed array of element type T, with a concatenation funct
113114mutable struct DArray{T,N,F} <: ArrayOp{T, N}
114115 domain:: ArrayDomain{N}
115116 subdomains:: AbstractArray{ArrayDomain{N}, N}
116- chunks:: AbstractArray{Union{Chunk,Thunk} , N}
117+ chunks:: AbstractArray{Any , N}
117118 concat:: F
118119 function DArray {T,N,F} (domain, subdomains, chunks, concat:: Function ) where {T, N,F}
119120 new (domain, subdomains, chunks, concat)
@@ -135,18 +136,18 @@ domainchunks(d::DArray) = d.subdomains
135136size (x:: DArray ) = size (domain (x))
136137stage (ctx, c:: DArray ) = c
137138
138- function collect (ctx :: Context , d:: DArray ; tree= false , options = nothing )
139- a = compute (ctx, d; options = options )
139+ function collect (d:: DArray ; tree= false )
140+ a = fetch (d )
140141
141142 if isempty (d. chunks)
142143 return Array {eltype(d)} (undef, size (d)... )
143144 end
144145
145146 dimcatfuncs = [(x... ) -> d. concat (x... , dims= i) for i in 1 : ndims (d)]
146147 if tree
147- collect (treereduce_nd (delayed .( dimcatfuncs), a. chunks))
148+ collect (fetch ( treereduce_nd (map (x -> ((args ... ,) -> Dagger . @spawn x (args ... )) , dimcatfuncs), a. chunks) ))
148149 else
149- treereduce_nd (dimcatfuncs, asyncmap (collect , a. chunks))
150+ treereduce_nd (dimcatfuncs, asyncmap (fetch , a. chunks))
150151 end
151152end
152153
@@ -209,53 +210,33 @@ _cumsum(x::AbstractArray) = length(x) == 0 ? Int[] : cumsum(x)
209210function lookup_parts (ps:: AbstractArray , subdmns:: DomainBlocks{N} , d:: ArrayDomain{N} ) where N
210211 groups = map (group_indices, subdmns. cumlength, indexes (d))
211212 sz = map (length, groups)
212- pieces = Array {Union{Chunk,Thunk} } (undef, sz)
213+ pieces = Array {Any } (undef, sz)
213214 for i = CartesianIndices (sz)
214215 idx_and_dmn = map (getindex, groups, i. I)
215216 idx = map (x-> x[1 ], idx_and_dmn)
216217 dmn = ArrayDomain (map (x-> x[2 ], idx_and_dmn))
217- pieces[i] = delayed ( getindex) (ps[idx... ], project (subdmns[idx... ], dmn))
218+ pieces[i] = Dagger . @spawn getindex (ps[idx... ], project (subdmns[idx... ], dmn))
218219 end
219220 out_cumlength = map (g-> _cumsum (map (x-> length (x[2 ]), g)), groups)
220221 out_dmn = DomainBlocks (ntuple (x-> 1 ,Val (N)), out_cumlength)
221222 pieces, out_dmn
222223end
223224
224-
225225"""
226- compute(ctx::Context, x::DArray; persist=true, options=nothing)
227-
228- A `DArray` object may contain a thunk in it, in which case
229- we first turn it into a `Thunk` and then compute it.
230- """
231- function compute (ctx:: Context , x:: DArray ; persist= true , options= nothing )
232- thunk = thunkize (ctx, x, persist= persist)
233- if isa (thunk, Thunk)
234- compute (ctx, thunk; options= options)
235- else
236- x
237- end
238- end
239-
240- """
241- thunkize(ctx::Context, c::DArray; persist=true)
226+ Base.fetch(c::DArray)
242227
243228If a `DArray` tree has a `Thunk` in it, make the whole thing a big thunk.
244229"""
245- function thunkize (ctx :: Context , c:: DArray ; persist = true )
230+ function Base . fetch ( c:: DArray )
246231 if any (istask, chunks (c))
247232 thunks = chunks (c)
248233 sz = size (thunks)
249234 dmn = domain (c)
250235 dmnchunks = domainchunks (c)
251- if persist
252- foreach (persist!, thunks)
253- end
254- Thunk (map (thunk-> nothing => thunk, thunks)... ; meta= true ) do results...
236+ fetch (Dagger. spawn (Options (meta= true ), thunks... ) do results...
255237 t = eltype (results[1 ])
256- DArray (t, dmn, dmnchunks,
257- reshape (Union{Chunk,Thunk}[results... ], sz))
258- end
238+ DArray (t, dmn, dmnchunks, reshape (Any[results... ], sz))
239+ end )
259240 else
260241 c
261242 end
@@ -335,19 +316,18 @@ function stage(ctx::Context, d::Distribute)
335316 cs = map (d. domainchunks) do idx
336317 chunks = cached_stage (ctx, x[idx]). chunks
337318 shape = size (chunks)
338- ( delayed ( ) do shape, parts...
319+ Dagger . spawn (shape, chunks ... ) do shape, parts...
339320 if prod (shape) == 0
340321 return Array {T} (undef, shape)
341322 end
342323 dimcatfuncs = [(x... ) -> concat (x... , dims= i) for i in 1 : length (shape)]
343324 ps = reshape (Any[parts... ], shape)
344325 collect (treereduce_nd (dimcatfuncs, ps))
345- end )(shape, chunks ... )
326+ end
346327 end
347328 else
348- cs = map (c -> delayed ( identity) (d. data[c]), d. domainchunks)
329+ cs = map (c -> (Dagger . @spawn identity (d. data[c]) ), d. domainchunks)
349330 end
350-
351331 DArray (
352332 eltype (d. data),
353333 domain (d. data),
@@ -357,7 +337,7 @@ function stage(ctx::Context, d::Distribute)
357337end
358338
359339function distribute (x:: AbstractArray , dist)
360- compute (Distribute (dist, x))
340+ fetch (Distribute (dist, x))
361341end
362342
363343function distribute (x:: AbstractArray{T,N} , n:: NTuple{N} ) where {T,N}
0 commit comments