Skip to content

Commit c7a2c5a

Browse files
traktofonandreasnoack
authored andcommitted
New implementation of broadcast for DArray.
All mapreduce tests passing.
1 parent 83af5b6 commit c7a2c5a

File tree

3 files changed

+35
-35
lines changed

3 files changed

+35
-35
lines changed

src/darray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ function Base.setindex!(a::Array, s::SubDArray,
708708
# partial chunk
709709
@async a[idxs...] =
710710
remotecall_fetch(d.pids[i]) do
711-
view(localpart(d), [K[j]-first(K_c[j])+1 for j=1:length(J)]...)
711+
view(localpart(d), [K[j].-first(K_c[j]).+1 for j=1:length(J)]...)
712712
end
713713
end
714714
end

src/mapreduce.jl

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,42 @@ function Base.map!(f::F, dest::DArray, src::DArray) where {F}
1414
return dest
1515
end
1616

17-
#Base.Broadcast._containertype(::Type{D}) where {D<:DArray} = DArray
18-
19-
Base.BroadcastStyle(::Type{DArray}, ::Type{DArray}) = DArray
20-
Base.BroadcastStyle(::Type{DArray}, ::Type{Array}) = DArray
21-
Base.BroadcastStyle(::Type{DArray}, ct) = DArray
22-
#Base.Broadcast.promote_containertype(::Type{Array}, ::Type{DArray}) = DArray
23-
#Base.Broadcast.promote_containertype(ct, ::Type{DArray}) = DArray
24-
25-
Base.Broadcast.broadcast_indices(::Type{DArray}, A) = indices(A)
26-
Base.Broadcast.broadcast_indices(::Type{DArray}, A::Ref) = ()
27-
28-
# FixMe!
29-
## 1. Support for arbitrary indices including OneTo
30-
## 2. This is as type unstable as it can be. Overhead might not matter too much for DArrays though.
31-
function Base.broadcast(f, ::Type{DArray}, ::Nothing, ::Nothing, As...)
32-
T = Base.Broadcast._broadcast_eltype(f, As...)
33-
shape = Base.Broadcast.broadcast_indices(As...)
34-
iter = Base.CartesianIndices(shape)
35-
D = DArray(map(length, shape)) do I
36-
Base.Broadcast.broadcast_c(f, Array,
37-
map(a -> isa(a, Union{Number,Ref}) ? a :
38-
localtype(a)(a[ntuple(i -> i > ndims(a) ? 1 : (size(a, i) == 1 ? (1:1) : I[i]), length(shape))...]), As)...)
17+
# new broadcasting implementation for julia-0.7
18+
19+
Base.BroadcastStyle(::Type{<:DArray}) = Broadcast.ArrayStyle{DArray}()
20+
Base.BroadcastStyle(::Type{<:DArray}, ::Any) = Broadcast.ArrayStyle{DArray}()
21+
22+
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}, ::Type{ElType}) where {ElType}
23+
DA = find_darray(bc)
24+
similar(DA, ElType)
25+
end
26+
27+
"`DA = find_darray(As)` returns the first DArray among the arguments."
28+
find_darray(bc::Base.Broadcast.Broadcasted) = find_darray(bc.args)
29+
find_darray(args::Tuple) = find_darray(find_darray(args[1]), Base.tail(args))
30+
find_darray(x) = x
31+
find_darray(a::DArray, rest) = a
32+
find_darray(::Any, rest) = find_darray(rest)
33+
34+
function Base.copyto!(dest::DArray, bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}})
35+
@sync for p in procs(dest)
36+
@async remotecall_fetch(p) do
37+
copyto!(localpart(dest), rewrite_local(bc))
38+
end
3939
end
40-
return D
40+
dest
4141
end
4242

43+
"""
44+
Transform a Broadcasted{Broadcast.ArrayStyle{DArray}} object into an equivalent
45+
Broadcasted{Broadcast.DefaultArrayStyle} object for the localparts.
46+
"""
47+
rewrite_local(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) = Broadcast.broadcasted(bc.f, rewrite_local(bc.args)...)
48+
rewrite_local(args::Tuple) = map(rewrite_local, args)
49+
rewrite_local(a::DArray) = localpart(a)
50+
rewrite_local(x) = x
51+
52+
4353
function Base.reduce(f, d::DArray)
4454
results = asyncmap(procs(d)) do p
4555
remotecall_fetch(p, f, d) do (f, d)
@@ -128,17 +138,6 @@ function nnz(A::DArray)
128138
return reduce(+, B)
129139
end
130140

131-
# reduce like
132-
# for (fn, fr) in ((:sum, :+),
133-
# (:prod, :*),
134-
# (:maximum, :max),
135-
# (:minimum, :min),
136-
# (:any, :|),
137-
# (:all, :&))
138-
# @eval (Base.$fn)(d::DArray) = reduce($fr, d)
139-
# @eval (Base.$fn)(f, d::DArray) = mapreduce(f, $fr, d)
140-
# end
141-
142141
function Base.extrema(d::DArray)
143142
r = asyncmap(procs(d)) do p
144143
remotecall_fetch(p) do

test/darray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ check_leaks()
182182

183183
@testset "test map / reduce" begin
184184
D2 = map(x->1, D)
185+
@test D2 isa DArray
185186
@test reduce(+, D2) == 100
186187
close(D2)
187188
end

0 commit comments

Comments
 (0)