Skip to content

Commit 0331fdf

Browse files
committed
Fix broadcasting into DistributedArray
1 parent f28343c commit 0331fdf

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# We define a custom ArrayStyle here since we need to keep track of
66
# the fact that it is Distributed and what kind of underlying broadcast behaviour
77
# we will encounter.
8-
struct DArrayStyle{Style <: BroadcastStyle} <: Broadcast.AbstractArrayStyle{Any} end
8+
struct DArrayStyle{Style <: Union{Nothing,BroadcastStyle}} <: Broadcast.AbstractArrayStyle{Any} end
99
DArrayStyle(::S) where {S} = DArrayStyle{S}()
1010
DArrayStyle(::S, ::Val{N}) where {S,N} = DArrayStyle(S(Val(N)))
1111
DArrayStyle(::Val{N}) where N = DArrayStyle{Broadcast.DefaultArrayStyle{N}}()
@@ -119,7 +119,7 @@ end
119119

120120
# Distribute broadcast
121121
# TODO: How to decide on cuts
122-
@inline bcdistribute(bc::Broadcasted{Style}) where Style = Broadcasted{DArrayStyle{Style}}(bc.f, bcdistribute_args(bc.args), bc.axes)
122+
@inline bcdistribute(bc::Broadcasted{Style}) where Style<:Union{Nothing,BroadcastStyle} = Broadcasted{DArrayStyle{Style}}(bc.f, bcdistribute_args(bc.args), bc.axes)
123123
@inline bcdistribute(bc::Broadcasted{Style}) where Style<:DArrayStyle = Broadcasted{Style}(bc.f, bcdistribute_args(bc.args), bc.axes)
124124

125125
# ask BroadcastStyle to decide if argument is in need of being distributed
@@ -135,7 +135,7 @@ bcdistribute_args(args::Tuple{Any}) = (bcdistribute(args[1]),)
135135
bcdistribute_args(args::Tuple{}) = ()
136136

137137
# dropping axes here since recomputing is easier
138-
@inline bclocal(bc::Broadcasted{DArrayStyle{Style}}, idxs) where Style = Broadcasted{Style}(bc.f, bclocal_args(_bcview(axes(bc), idxs), bc.args))
138+
@inline bclocal(bc::Broadcasted{DArrayStyle{Style}}, idxs) where Style<:Union{Nothing,BroadcastStyle} = Broadcasted{Style}(bc.f, bclocal_args(_bcview(axes(bc), idxs), bc.args))
139139

140140
# bclocal will do a view of the data and the copy it over
141141
# except when the data already is local

0 commit comments

Comments
 (0)