Skip to content

Commit 2ee6c28

Browse files
authored
Merge pull request #164 from JuliaParallel/vc/broadcast
Make broadcast implementation work with chunktype
2 parents 9b6111b + b7d5c17 commit 2ee6c28

File tree

4 files changed

+156
-24
lines changed

4 files changed

+156
-24
lines changed

src/DistributedArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export close, d_closeall
2727
include("darray.jl")
2828
include("core.jl")
2929
include("serialize.jl")
30+
include("broadcast.jl")
3031
include("mapreduce.jl")
3132
include("linalg.jl")
3233
include("sort.jl")

src/broadcast.jl

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
###
2+
# Distributed broadcast implementation
3+
##
4+
5+
using Base.Broadcast
6+
import Base.Broadcast: BroadcastStyle, Broadcasted, _max
7+
8+
# We define a custom ArrayStyle here since we need to keep track of
9+
# the fact that it is Distributed and what kind of underlying broadcast behaviour
10+
# we will encounter.
11+
struct DArrayStyle{Style <: BroadcastStyle} <: Broadcast.AbstractArrayStyle{Any} end
12+
DArrayStyle(::S) where {S} = DArrayStyle{S}()
13+
DArrayStyle(::S, ::Val{N}) where {S,N} = DArrayStyle(S(Val(N)))
14+
DArrayStyle(::Val{N}) where N = DArrayStyle{Broadcast.DefaultArrayStyle{N}}()
15+
16+
BroadcastStyle(::Type{<:DArray{<:Any, N, A}}) where {N, A} = DArrayStyle(BroadcastStyle(A), Val(N))
17+
18+
# promotion rules
19+
function BroadcastStyle(::DArrayStyle{AStyle}, ::DArrayStyle{BStyle}) where {AStyle, BStyle}
20+
DArrayStyle{BroadcastStyle(AStyle, BStyle)}()
21+
end
22+
23+
# # deal with one layer deep lazy arrays
24+
# BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where T <: DArray = BroadcastStyle(T)
25+
# BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where T <: DArray = BroadcastStyle(T)
26+
# BroadcastStyle(::Type{<:SubArray{<:Any,<:Any,<:T}}) where T <: DArray = BroadcastStyle(T)
27+
28+
# # This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray
29+
# # and we could define our methods in terms of Union{DArray, WrappedArray{<:Any, <:DArray}}
30+
# const DDestArray = Union{DArray,
31+
# LinearAlgebra.Transpose{<:Any,<:DArray},
32+
# LinearAlgebra.Adjoint{<:Any,<:DArray},
33+
# SubArray{<:Any, <:Any, <:DArray}}
34+
const DDestArray = DArray
35+
36+
# This method is responsible for selection the output type of broadcast
37+
function Base.similar(bc::Broadcasted{<:DArrayStyle{Style}}, ::Type{ElType}) where {Style, ElType}
38+
DArray(map(length, axes(bc))) do I
39+
# create fake Broadcasted for underlying ArrayStyle
40+
bc′ = Broadcasted{Style}(identity, (), map(length, I))
41+
similar(bc′, ElType)
42+
end
43+
end
44+
45+
##
46+
# We purposefully only specialise `copyto!`,
47+
# Broadcast implementation that defers to the underlying BroadcastStyle. We can't
48+
# assume that `getindex` is fast, furthermore we can't assume that the distribution of
49+
# DArray accross workers is equal or that the underlying array type is consistent.
50+
#
51+
# Implementation:
52+
# - first distribute all arguments
53+
# - Q: How do decide on the cuts
54+
# - then localise arguments on each node
55+
##
56+
@inline function Base.copyto!(dest::DDestArray, bc::Broadcasted)
57+
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
58+
59+
# Distribute Broadcasted
60+
# This will turn local AbstractArrays into DArrays
61+
dbc = bcdistribute(bc)
62+
63+
asyncmap(procs(dest)) do p
64+
remotecall_fetch(p) do
65+
# get the indices for the localpart
66+
lpidx = localpartindex(dest)
67+
@assert lpidx != 0
68+
# create a local version of the broadcast, by constructing views
69+
# Note: creates copies of the argument
70+
lbc = bclocal(dbc, dest.indices[lpidx])
71+
Base.copyto!(localpart(dest), lbc)
72+
return nothing
73+
end
74+
end
75+
return dest
76+
end
77+
78+
@inline function Base.copy(bc::Broadcasted{<:DArrayStyle})
79+
dbc = bcdistribute(bc)
80+
# TODO: teach DArray about axes since this is wrong for OffsetArrays
81+
DArray(map(length, axes(bc))) do I
82+
lbc = bclocal(dbc, I)
83+
copy(lbc)
84+
end
85+
end
86+
87+
# _bcview creates takes the shapes of a view and the shape of a broadcasted argument,
88+
# and produces the view over that argument that constitutes part of the broadcast
89+
# it is in a sense the inverse of _bcs in Base.Broadcast
90+
_bcview(::Tuple{}, ::Tuple{}) = ()
91+
_bcview(::Tuple{}, view::Tuple) = ()
92+
_bcview(shape::Tuple, ::Tuple{}) = (shape[1], _bcview(tail(shape), ())...)
93+
function _bcview(shape::Tuple, view::Tuple)
94+
return (_bcview1(shape[1], view[1]), _bcview(tail(shape), tail(view))...)
95+
end
96+
97+
# _bcview1 handles the logic for a single dimension
98+
function _bcview1(a, b)
99+
if a == 1 || a == 1:1
100+
return 1:1
101+
elseif first(a) <= first(b) <= last(a) &&
102+
first(a) <= last(b) <= last(b)
103+
return b
104+
else
105+
throw(DimensionMismatch("broadcast view could not be constructed"))
106+
end
107+
end
108+
109+
# Distribute broadcast
110+
# TODO: How to decide on cuts
111+
@inline bcdistribute(bc::Broadcasted{Style}) where Style = Broadcasted{DArrayStyle{Style}}(bc.f, bcdistribute_args(bc.args), bc.axes)
112+
@inline bcdistribute(bc::Broadcasted{Style}) where Style<:DArrayStyle = Broadcasted{Style}(bc.f, bcdistribute_args(bc.args), bc.axes)
113+
114+
# ask BroadcastStyle to decide if argument is in need of being distributed
115+
bcdistribute(x::T) where T = _bcdistribute(BroadcastStyle(T), x)
116+
_bcdistribute(::DArrayStyle, x) = x
117+
# Don't bother distributing singletons
118+
_bcdistribute(::Broadcast.AbstractArrayStyle{0}, x) = x
119+
_bcdistribute(::Broadcast.AbstractArrayStyle, x) = distribute(x)
120+
_bcdistribute(::Any, x) = x
121+
122+
@inline bcdistribute_args(args::Tuple) = (bcdistribute(args[1]), bcdistribute_args(tail(args))...)
123+
bcdistribute_args(args::Tuple{Any}) = (bcdistribute(args[1]),)
124+
bcdistribute_args(args::Tuple{}) = ()
125+
126+
# dropping axes here since recomputing is easier
127+
@inline bclocal(bc::Broadcasted{DArrayStyle{Style}}, idxs) where Style = Broadcasted{Style}(bc.f, bclocal_args(_bcview(axes(bc), idxs), bc.args))
128+
129+
# bclocal will do a view of the data and the copy it over, except
130+
# when the shard match precisly (TODO: make sure that the invariant holds more often)
131+
function bclocal(x::DArray{T, N, AT}, idxs) where {T, N, AT}
132+
bcidxs = _bcview(axes(x), idxs)
133+
lpidx = localpartindex(x)
134+
if lpidx != 0 && all(x.indices[lpidx] .== bcidxs)
135+
return localpart(x)
136+
end
137+
return convert(__type(AT), view(x, bcidxs...))
138+
end
139+
bclocal(x, idxs) = x
140+
141+
@inline bclocal_args(idxs, args::Tuple) = (bclocal(args[1], idxs), bclocal_args(idxs, tail(args))...)
142+
bclocal_args(idxs, args::Tuple{Any}) = (bclocal(args[1], idxs),)
143+
bclocal_args(idxs, args::Tuple{}) = ()
144+
145+
__type(T) = Base.typename(T).wrapper

src/mapreduce.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,6 @@ function Base.map!(f::F, dest::DArray, src::DArray{<:Any,<:Any,A}) where {F,A}
1515
return dest
1616
end
1717

18-
# new broadcasting implementation for julia-0.7
19-
20-
Base.BroadcastStyle(::Type{<:DArray}) = Broadcast.ArrayStyle{DArray}()
21-
Base.BroadcastStyle(::Type{<:DArray}, ::Any) = Broadcast.ArrayStyle{DArray}()
22-
23-
function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}})
24-
T = Base.Broadcast.combine_eltypes(bc.f, bc.args)
25-
shape = Base.Broadcast.combine_axes(bc.args...)
26-
iter = Base.CartesianIndices(shape)
27-
D = DArray(map(length, shape)) do I
28-
A = map(bc.args) do a
29-
if isa(a, Union{Number,Ref})
30-
return a
31-
else
32-
return localtype(a)(
33-
a[ntuple(i -> i > ndims(a) ? 1 : (size(a, i) == 1 ? (1:1) : I[i]), length(shape))...]
34-
)
35-
end
36-
end
37-
broadcast(bc.f, A...)
38-
end
39-
return D
40-
end
41-
4218
function Base.reduce(f, d::DArray)
4319
results = asyncmap(procs(d)) do p
4420
remotecall_fetch(p) do

test/darray.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,16 @@ check_leaks()
846846
f = 2 .* e
847847
@test Array(f) == 2 .* Array(e)
848848
@test Array(map(x -> sum(x) .+ 2, e)) == map(x -> sum(x) .+ 2, e)
849+
850+
@testset "test nested broadcast" begin
851+
g = a .- m .* sin.(c)
852+
@test Array(g) == Array(a) .- Array(m) .* sin.(Array(c))
853+
end
854+
855+
# @testset "lazy wrapped broadcast" begin
856+
# l = similar(a)
857+
# l[1:10, :] .= view(a, 1:10, : )
858+
# end
849859
d_closeall()
850860
end
851861

0 commit comments

Comments
 (0)