Skip to content

Commit 9443ffb

Browse files
committed
DArray: Make allocations dispatchable
1 parent ee6b0ce commit 9443ffb

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

src/array/alloc.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ export partition
44

55
mutable struct AllocateArray{T,N} <: ArrayOp{T,N}
66
eltype::Type{T}
7-
f::Function
7+
f
8+
want_index::Bool
89
domain::ArrayDomain{N}
910
domainchunks
1011
partitioning::AbstractBlocks
@@ -23,17 +24,29 @@ function partition(p::AbstractBlocks, dom::ArrayDomain)
2324
map(_cumlength, map(length, indexes(dom)), p.blocksize))
2425
end
2526

27+
function allocate_array(f, T, idx, sz)
28+
new_f = allocate_array_func(thunk_processor(), f)
29+
return new_f(idx, T, sz)
30+
end
31+
function allocate_array(f, T, sz)
32+
new_f = allocate_array_func(thunk_processor(), f)
33+
return new_f(T, sz)
34+
end
35+
allocate_array_func(::Processor, f) = f
2636
function stage(ctx, a::AllocateArray)
27-
alloc(idx, sz) = a.f(idx, a.eltype, sz)
28-
thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(a.domainchunks)]
37+
if a.want_index
38+
thunks = [Dagger.@spawn allocate_array(a.f, a.eltype, i, size(x)) for (i, x) in enumerate(a.domainchunks)]
39+
else
40+
thunks = [Dagger.@spawn allocate_array(a.f, a.eltype, size(x)) for (i, x) in enumerate(a.domainchunks)]
41+
end
2942
return DArray(a.eltype, a.domain, a.domainchunks, thunks, a.partitioning)
3043
end
3144

3245
const BlocksOrAuto = Union{Blocks{N} where N, AutoBlocks}
3346

3447
function Base.rand(p::Blocks, eltype::Type, dims::Dims)
3548
d = ArrayDomain(map(x->1:x, dims))
36-
a = AllocateArray(eltype, (_, x...) -> rand(x...), d, partition(p, d), p)
49+
a = AllocateArray(eltype, rand, false, d, partition(p, d), p)
3750
return _to_darray(a)
3851
end
3952
Base.rand(p::BlocksOrAuto, T::Type, dims::Integer...) = rand(p, T, dims)
@@ -45,7 +58,7 @@ Base.rand(::AutoBlocks, eltype::Type, dims::Dims) =
4558

4659
function Base.randn(p::Blocks, eltype::Type, dims::Dims)
4760
d = ArrayDomain(map(x->1:x, dims))
48-
a = AllocateArray(eltype, (_, x...) -> randn(x...), d, partition(p, d), p)
61+
a = AllocateArray(eltype, randn, false, d, partition(p, d), p)
4962
return _to_darray(a)
5063
end
5164
Base.randn(p::BlocksOrAuto, T::Type, dims::Integer...) = randn(p, T, dims)
@@ -57,7 +70,7 @@ Base.randn(::AutoBlocks, eltype::Type, dims::Dims) =
5770

5871
function sprand(p::Blocks, eltype::Type, dims::Dims, sparsity::AbstractFloat)
5972
d = ArrayDomain(map(x->1:x, dims))
60-
a = AllocateArray(eltype, (_, T, _dims) -> sprand(T, _dims..., sparsity), d, partition(p, d), p)
73+
a = AllocateArray(eltype, (T, _dims) -> sprand(T, _dims..., sparsity), false, d, partition(p, d), p)
6174
return _to_darray(a)
6275
end
6376
sprand(p::BlocksOrAuto, T::Type, dims_and_sparsity::Real...) =
@@ -73,7 +86,7 @@ sprand(::AutoBlocks, eltype::Type, dims::Dims, sparsity::AbstractFloat) =
7386

7487
function Base.ones(p::Blocks, eltype::Type, dims::Dims)
7588
d = ArrayDomain(map(x->1:x, dims))
76-
a = AllocateArray(eltype, (_, x...) -> ones(x...), d, partition(p, d), p)
89+
a = AllocateArray(eltype, ones, false, d, partition(p, d), p)
7790
return _to_darray(a)
7891
end
7992
Base.ones(p::BlocksOrAuto, T::Type, dims::Integer...) = ones(p, T, dims)
@@ -85,7 +98,7 @@ Base.ones(::AutoBlocks, eltype::Type, dims::Dims) =
8598

8699
function Base.zeros(p::Blocks, eltype::Type, dims::Dims)
87100
d = ArrayDomain(map(x->1:x, dims))
88-
a = AllocateArray(eltype, (_, x...) -> zeros(x...), d, partition(p, d), p)
101+
a = AllocateArray(eltype, zeros, false, d, partition(p, d), p)
89102
return _to_darray(a)
90103
end
91104
Base.zeros(p::BlocksOrAuto, T::Type, dims::Integer...) = zeros(p, T, dims)

src/array/darray.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -306,16 +306,12 @@ function Base.isequal(x::ArrayOp, y::ArrayOp)
306306
x === y
307307
end
308308

309-
function Base.similar(x::DArray{T,N}) where {T,N}
310-
alloc(idx, sz) = Array{T,N}(undef, sz)
311-
thunks = [Dagger.@spawn alloc(i, size(x)) for (i, x) in enumerate(x.subdomains)]
312-
return DArray(T, x.domain, x.subdomains, thunks, x.partitioning, x.concat)
313-
end
314-
309+
struct AllocateUndef{S} end
310+
(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = Array{S,N}(undef, dims)
315311
function Base.similar(A::DArray{T,N} where T, ::Type{S}, dims::Dims{N}) where {S,N}
316312
d = ArrayDomain(map(x->1:x, dims))
317313
p = A.partitioning
318-
a = AllocateArray(S, (_, _, x...) -> Array{S,N}(undef, x...), d, partition(p, d), p)
314+
a = AllocateArray(S, AllocateUndef{S}(), false, d, partition(p, d), p)
319315
return _to_darray(a)
320316
end
321317

0 commit comments

Comments
 (0)