|
| 1 | +### |
| 2 | +# Distributed broadcast implementation |
| 3 | +## |
| 4 | + |
| 5 | +using Base.Broadcast |
| 6 | +import Base.Broadcast: BroadcastStyle, Broadcasted, _max |
| 7 | + |
| 8 | +# We define a custom ArrayStyle here since we need to keep track of |
| 9 | +# the fact that it is Distributed and what kind of underlying broadcast behaviour |
| 10 | +# we will encounter. |
| 11 | +struct DArrayStyle{Style <: BroadcastStyle} <: Broadcast.AbstractArrayStyle{Any} end |
| 12 | +DArrayStyle(::S) where {S} = DArrayStyle{S}() |
| 13 | +DArrayStyle(::S, ::Val{N}) where {S,N} = DArrayStyle(S(Val(N))) |
| 14 | +DArrayStyle(::Val{N}) where N = DArrayStyle{Broadcast.DefaultArrayStyle{N}}() |
| 15 | + |
| 16 | +BroadcastStyle(::Type{<:DArray{<:Any, N, A}}) where {N, A} = DArrayStyle(BroadcastStyle(A), Val(N)) |
| 17 | + |
| 18 | +# promotion rules |
| 19 | +function BroadcastStyle(::DArrayStyle{AStyle}, ::DArrayStyle{BStyle}) where {AStyle, BStyle} |
| 20 | + DArrayStyle{BroadcastStyle(AStyle, BStyle)}() |
| 21 | +end |
| 22 | + |
| 23 | +# # deal with one layer deep lazy arrays |
| 24 | +# BroadcastStyle(::Type{<:LinearAlgebra.Transpose{<:Any,T}}) where T <: DArray = BroadcastStyle(T) |
| 25 | +# BroadcastStyle(::Type{<:LinearAlgebra.Adjoint{<:Any,T}}) where T <: DArray = BroadcastStyle(T) |
| 26 | +# BroadcastStyle(::Type{<:SubArray{<:Any,<:Any,<:T}}) where T <: DArray = BroadcastStyle(T) |
| 27 | + |
| 28 | +# # This Union is a hack. Ideally Base would have a Transpose <: WrappedArray <: AbstractArray |
| 29 | +# # and we could define our methods in terms of Union{DArray, WrappedArray{<:Any, <:DArray}} |
| 30 | +# const DDestArray = Union{DArray, |
| 31 | +# LinearAlgebra.Transpose{<:Any,<:DArray}, |
| 32 | +# LinearAlgebra.Adjoint{<:Any,<:DArray}, |
| 33 | +# SubArray{<:Any, <:Any, <:DArray}} |
| 34 | +const DDestArray = DArray |
| 35 | + |
| 36 | +# This method is responsible for selection the output type of broadcast |
| 37 | +function Base.similar(bc::Broadcasted{<:DArrayStyle{Style}}, ::Type{ElType}) where {Style, ElType} |
| 38 | + DArray(map(length, axes(bc))) do I |
| 39 | + # create fake Broadcasted for underlying ArrayStyle |
| 40 | + bc′ = Broadcasted{Style}(identity, (), map(length, I)) |
| 41 | + similar(bc′, ElType) |
| 42 | + end |
| 43 | +end |
| 44 | + |
| 45 | +## |
| 46 | +# We purposefully only specialise `copyto!`, |
| 47 | +# Broadcast implementation that defers to the underlying BroadcastStyle. We can't |
| 48 | +# assume that `getindex` is fast, furthermore we can't assume that the distribution of |
| 49 | +# DArray accross workers is equal or that the underlying array type is consistent. |
| 50 | +# |
| 51 | +# Implementation: |
| 52 | +# - first distribute all arguments |
| 53 | +# - Q: How do decide on the cuts |
| 54 | +# - then localise arguments on each node |
| 55 | +## |
| 56 | +@inline function Base.copyto!(dest::DDestArray, bc::Broadcasted) |
| 57 | + axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) |
| 58 | + |
| 59 | + # Distribute Broadcasted |
| 60 | + # This will turn local AbstractArrays into DArrays |
| 61 | + dbc = bcdistribute(bc) |
| 62 | + |
| 63 | + asyncmap(procs(dest)) do p |
| 64 | + remotecall_fetch(p) do |
| 65 | + # get the indices for the localpart |
| 66 | + lpidx = localpartindex(dest) |
| 67 | + @assert lpidx != 0 |
| 68 | + # create a local version of the broadcast, by constructing views |
| 69 | + # Note: creates copies of the argument |
| 70 | + lbc = bclocal(dbc, dest.indices[lpidx]) |
| 71 | + Base.copyto!(localpart(dest), lbc) |
| 72 | + return nothing |
| 73 | + end |
| 74 | + end |
| 75 | + return dest |
| 76 | +end |
| 77 | + |
| 78 | +@inline function Base.copy(bc::Broadcasted{<:DArrayStyle}) |
| 79 | + dbc = bcdistribute(bc) |
| 80 | + # TODO: teach DArray about axes since this is wrong for OffsetArrays |
| 81 | + DArray(map(length, axes(bc))) do I |
| 82 | + lbc = bclocal(dbc, I) |
| 83 | + copy(lbc) |
| 84 | + end |
| 85 | +end |
| 86 | + |
| 87 | +# _bcview creates takes the shapes of a view and the shape of a broadcasted argument, |
| 88 | +# and produces the view over that argument that constitutes part of the broadcast |
| 89 | +# it is in a sense the inverse of _bcs in Base.Broadcast |
| 90 | +_bcview(::Tuple{}, ::Tuple{}) = () |
| 91 | +_bcview(::Tuple{}, view::Tuple) = () |
| 92 | +_bcview(shape::Tuple, ::Tuple{}) = (shape[1], _bcview(tail(shape), ())...) |
| 93 | +function _bcview(shape::Tuple, view::Tuple) |
| 94 | + return (_bcview1(shape[1], view[1]), _bcview(tail(shape), tail(view))...) |
| 95 | +end |
| 96 | + |
| 97 | +# _bcview1 handles the logic for a single dimension |
| 98 | +function _bcview1(a, b) |
| 99 | + if a == 1 || a == 1:1 |
| 100 | + return 1:1 |
| 101 | + elseif first(a) <= first(b) <= last(a) && |
| 102 | + first(a) <= last(b) <= last(b) |
| 103 | + return b |
| 104 | + else |
| 105 | + throw(DimensionMismatch("broadcast view could not be constructed")) |
| 106 | + end |
| 107 | +end |
| 108 | + |
| 109 | +# Distribute broadcast |
| 110 | +# TODO: How to decide on cuts |
| 111 | +@inline bcdistribute(bc::Broadcasted{Style}) where Style = Broadcasted{DArrayStyle{Style}}(bc.f, bcdistribute_args(bc.args), bc.axes) |
| 112 | +@inline bcdistribute(bc::Broadcasted{Style}) where Style<:DArrayStyle = Broadcasted{Style}(bc.f, bcdistribute_args(bc.args), bc.axes) |
| 113 | + |
| 114 | +# ask BroadcastStyle to decide if argument is in need of being distributed |
| 115 | +bcdistribute(x::T) where T = _bcdistribute(BroadcastStyle(T), x) |
| 116 | +_bcdistribute(::DArrayStyle, x) = x |
| 117 | +# Don't bother distributing singletons |
| 118 | +_bcdistribute(::Broadcast.AbstractArrayStyle{0}, x) = x |
| 119 | +_bcdistribute(::Broadcast.AbstractArrayStyle, x) = distribute(x) |
| 120 | +_bcdistribute(::Any, x) = x |
| 121 | + |
| 122 | +@inline bcdistribute_args(args::Tuple) = (bcdistribute(args[1]), bcdistribute_args(tail(args))...) |
| 123 | +bcdistribute_args(args::Tuple{Any}) = (bcdistribute(args[1]),) |
| 124 | +bcdistribute_args(args::Tuple{}) = () |
| 125 | + |
| 126 | +# dropping axes here since recomputing is easier |
| 127 | +@inline bclocal(bc::Broadcasted{DArrayStyle{Style}}, idxs) where Style = Broadcasted{Style}(bc.f, bclocal_args(_bcview(axes(bc), idxs), bc.args)) |
| 128 | + |
| 129 | +# bclocal will do a view of the data and the copy it over, except |
| 130 | +# when the shard match precisly (TODO: make sure that the invariant holds more often) |
| 131 | +function bclocal(x::DArray{T, N, AT}, idxs) where {T, N, AT} |
| 132 | + bcidxs = _bcview(axes(x), idxs) |
| 133 | + lpidx = localpartindex(x) |
| 134 | + if lpidx != 0 && all(x.indices[lpidx] .== bcidxs) |
| 135 | + return localpart(x) |
| 136 | + end |
| 137 | + return convert(__type(AT), view(x, bcidxs...)) |
| 138 | +end |
| 139 | +bclocal(x, idxs) = x |
| 140 | + |
| 141 | +@inline bclocal_args(idxs, args::Tuple) = (bclocal(args[1], idxs), bclocal_args(idxs, tail(args))...) |
| 142 | +bclocal_args(idxs, args::Tuple{Any}) = (bclocal(args[1], idxs),) |
| 143 | +bclocal_args(idxs, args::Tuple{}) = () |
| 144 | + |
| 145 | +__type(T) = Base.typename(T).wrapper |
0 commit comments