Skip to content

Commit ec06e4c

Browse files
committed
Add a BroadcastStyle for AbstractFill
1 parent 295266c commit ec06e4c

File tree

3 files changed

+152
-94
lines changed

3 files changed

+152
-94
lines changed

src/FillArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
77
any, all, axes, isone, iszero, iterate, unique, allunique, permutedims, inv,
88
copy, vec, setindex!, count, ==, reshape, map, zero,
99
show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat,
10-
parent, similar, issorted, add_sum, accumulate, OneTo, permutedims
10+
parent, similar, issorted, add_sum, accumulate, OneTo, permutedims,
11+
real, imag, conj
1112

1213
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
1314
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,

src/fillbroadcast.jl

Lines changed: 113 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,79 @@ function mapreduce(f, op, A::AbstractFill, B::AbstractFill, Cs::AbstractArray...
7373
end
7474

7575

76-
### Unary broadcasting
76+
## BroadcastStyle
77+
78+
abstract type AbstractFillStyle{N} <: Broadcast.AbstractArrayStyle{N} end
79+
struct FillStyle{N} <: AbstractFillStyle{N} end
80+
struct ZerosStyle{N} <: AbstractFillStyle{N} end
81+
FillStyle{N}(::Val{M}) where {N,M} = FillStyle{M}()
82+
ZerosStyle{N}(::Val{M}) where {N,M} = ZerosStyle{M}()
83+
Broadcast.BroadcastStyle(::Type{<:AbstractFill{<:Any,N}}) where {N} = FillStyle{N}()
84+
Broadcast.BroadcastStyle(::Type{<:AbstractZeros{<:Any,N}}) where {N} = ZerosStyle{N}()
85+
Broadcast.BroadcastStyle(::FillStyle{M}, ::ZerosStyle{N}) where {M,N} = FillStyle{max(M,N)}()
86+
Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{2}) = S
87+
Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{1}) = S
88+
Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{0}) = S
89+
90+
_getindex_value(f::AbstractFill) = getindex_value(f)
91+
_getindex_value(x::Number) = x
92+
_getindex_value(x::Ref) = x[]
93+
function _getindex_value(bc::Broadcast.Broadcasted)
94+
bc.f(map(_getindex_value, bc.args)...)
95+
end
96+
97+
has_static_value(x) = false
98+
has_static_value(x::Union{AbstractZeros, AbstractOnes}) = true
99+
has_static_value(x::Broadcast.Broadcasted) = all(has_static_value, x.args)
77100

78-
function broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N}
79-
return Fill(op(getindex_value(r)), axes(r))
101+
function _iszeros(bc::Broadcast.Broadcasted)
102+
all(has_static_value, bc.args) && _iszero(_getindex_value(bc))
80103
end
104+
# conservative check for zeros. In most cases we can't really compare with zero
105+
_iszero(x::Union{Number, AbstractArray}) = iszero(x)
106+
_iszero(_) = false
81107

82-
broadcasted(::DefaultArrayStyle, ::typeof(+), r::AbstractZeros) = r
83-
broadcasted(::DefaultArrayStyle, ::typeof(-), r::AbstractZeros) = r
84-
broadcasted(::DefaultArrayStyle, ::typeof(+), r::AbstractOnes) = r
108+
function _isones(bc::Broadcast.Broadcasted)
109+
all(has_static_value, bc.args) && _isone(_getindex_value(bc))
110+
end
111+
# conservative check for ones. In most cases we can't really compare with one
112+
_isone(x::Union{Number, AbstractArray}) = isone(x)
113+
_isone(_) = false
114+
115+
_isfill(bc::Broadcast.Broadcasted) = all(_isfill, bc.args)
116+
_isfill(f::AbstractFill) = true
117+
_isfill(f::Number) = true
118+
_isfill(f::Ref) = true
119+
_isfill(::Any) = false
120+
121+
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{N}}) where {N}
122+
if _iszeros(bc)
123+
return Zeros(typeof(_getindex_value(bc)), axes(bc))
124+
elseif _isones(bc)
125+
return Ones(typeof(_getindex_value(bc)), axes(bc))
126+
elseif _isfill(bc)
127+
return Fill(_getindex_value(bc), axes(bc))
128+
else
129+
# fallback style
130+
S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{N}}
131+
copy(convert(S, bc))
132+
end
133+
end
134+
# make the zero-dimensional case consistent with Base
135+
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}})
136+
S = Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}
137+
copy(convert(S, bc))
138+
end
85139

86-
broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::AbstractZeros{T,N}) where {T,N} = r
87-
broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::AbstractOnes{T,N}) where {T,N} = r
88-
broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::AbstractZeros{T,N}) where {T,N} = Zeros{real(T)}(axes(r))
89-
broadcasted(::DefaultArrayStyle{N}, ::typeof(real), r::AbstractOnes{T,N}) where {T,N} = Ones{real(T)}(axes(r))
90-
broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::AbstractZeros{T,N}) where {T,N} = Zeros{real(T)}(axes(r))
91-
broadcasted(::DefaultArrayStyle{N}, ::typeof(imag), r::AbstractOnes{T,N}) where {T,N} = Zeros{real(T)}(axes(r))
140+
# some cases that preserve 0d
141+
function broadcast_preserving_0d(f, As...)
142+
bc = Base.broadcasted(f, As...)
143+
r = copy(bc)
144+
length(axes(bc)) == 0 ? Fill(r) : r
145+
end
146+
for f in (:real, :imag, :conj)
147+
@eval ($f)(A::AbstractFill) = broadcast_preserving_0d($f, A)
148+
end
92149

93150
### Binary broadcasting
94151

@@ -100,12 +157,6 @@ broadcasted_zeros(f, a, b, elt, ax) = Zeros{elt}(ax)
100157
broadcasted_ones(f, a, elt, ax) = Ones{elt}(ax)
101158
broadcasted_ones(f, a, b, elt, ax) = Ones{elt}(ax)
102159

103-
function broadcasted(::DefaultArrayStyle, op, a::AbstractFill, b::AbstractFill)
104-
val = op(getindex_value(a), getindex_value(b))
105-
ax = broadcast_shape(axes(a), axes(b))
106-
return broadcasted_fill(op, a, b, val, ax)
107-
end
108-
109160
function _broadcasted_zeros(f, a, b)
110161
elt = Base.Broadcast.combine_eltypes(f, (a, b))
111162
ax = broadcast_shape(axes(a), axes(b))
@@ -122,57 +173,40 @@ function _broadcasted_nan(f, a, b)
122173
return broadcasted_fill(f, a, b, val, ax)
123174
end
124175

125-
broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(+, a, b)
126-
broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractOnes, b::AbstractZeros) = _broadcasted_ones(+, a, b)
127-
broadcasted(::DefaultArrayStyle, ::typeof(+), a::AbstractZeros, b::AbstractOnes) = _broadcasted_ones(+, a, b)
128-
129-
broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(-, a, b)
130-
broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractOnes, b::AbstractZeros) = _broadcasted_ones(-, a, b)
131-
broadcasted(::DefaultArrayStyle, ::typeof(-), a::AbstractOnes, b::AbstractOnes) = _broadcasted_zeros(-, a, b)
132-
133-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractZerosVector, b::AbstractZerosVector) = _broadcasted_zeros(+, a, b)
134-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractOnesVector, b::AbstractZerosVector) = _broadcasted_ones(+, a, b)
135-
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::AbstractZerosVector, b::AbstractOnesVector) = _broadcasted_ones(+, a, b)
136-
137-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::AbstractZerosVector, b::AbstractZerosVector) = _broadcasted_zeros(-, a, b)
138-
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::AbstractOnesVector, b::AbstractZerosVector) = _broadcasted_ones(-, a, b)
139-
140-
141-
broadcasted(::DefaultArrayStyle, ::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(*, a, b)
142-
143176
# In following, need to restrict to <: Number as otherwise we cannot infer zero from type
144177
# TODO: generalise to things like SVector
145178
for op in (:*, :/)
146179
@eval begin
147-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractOnes) = _broadcasted_zeros($op, a, b)
148-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractFill{<:Number}) = _broadcasted_zeros($op, a, b)
149-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::Number) = _broadcasted_zeros($op, a, b)
150-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b)
151-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b)
152-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b)
153-
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b)
180+
broadcasted(::typeof($op), a::AbstractZeros, b::AbstractFill{<:Number}) = _broadcasted_zeros($op, a, b)
181+
broadcasted(::typeof($op), a::AbstractZeros, b::Number) = _broadcasted_zeros($op, a, b)
182+
broadcasted(::typeof($op), a::AbstractZeros, b::AbstractOnes) = _broadcasted_zeros($op, a, b)
183+
broadcasted(::typeof($op), a::AbstractZeros, b::AbstractRange) = _broadcasted_zeros($op, a, b)
184+
broadcasted(::typeof($op), a::AbstractZeros, b::AbstractArray{<:Number}) = _broadcasted_zeros($op, a, b)
185+
broadcasted(::typeof($op), a::AbstractZeros, b::Base.Broadcast.Broadcasted) = _broadcasted_zeros($op, a, b)
154186
end
155187
end
156188

157189
for op in (:*, :\)
158190
@eval begin
159-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractOnes, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
160-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractFill{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
161-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Number, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
162-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
163-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractArray{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
164-
broadcasted(::DefaultArrayStyle, ::typeof($op), a::Base.Broadcast.Broadcasted, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
165-
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
191+
broadcasted(::typeof($op), a::AbstractOnes, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
192+
broadcasted(::typeof($op), a::AbstractFill{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
193+
broadcasted(::typeof($op), a::Number, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
194+
broadcasted(::typeof($op), a::AbstractRange, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
195+
broadcasted(::typeof($op), a::AbstractArray{<:Number}, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
196+
broadcasted(::typeof($op), a::Base.Broadcast.Broadcasted, b::AbstractZeros) = _broadcasted_zeros($op, a, b)
166197
end
167198
end
199+
broadcasted(::typeof(*), a::AbstractZeros, b::AbstractZeros) = _broadcasted_zeros(*, a, b)
200+
broadcasted(::typeof(/), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(/, a, b)
201+
broadcasted(::typeof(\), a::AbstractZeros, b::AbstractZeros) = _broadcasted_nan(\, a, b)
168202

169-
for op in (:*, :/, :\)
170-
@eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b)
171-
end
203+
# for op in (:*, :/, :\)
204+
# @eval broadcasted(::typeof($op), a::AbstractOnes, b::AbstractOnes) = _broadcasted_ones($op, a, b)
205+
# end
172206

173-
for op in (:/, :\)
174-
@eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b)
175-
end
207+
# for op in (:/, :\)
208+
# @eval broadcasted(::DefaultArrayStyle, ::typeof($op), a::AbstractZeros{<:Number}, b::AbstractZeros{<:Number}) = _broadcasted_nan($op, a, b)
209+
# end
176210

177211
# special case due to missing converts for ranges
178212
_range_convert(::Type{AbstractVector{T}}, a::AbstractRange{T}) where T = a
@@ -205,65 +239,60 @@ _range_convert(::Type{AbstractVector{T}}, a::ZerosVector) where T = ZerosVector{
205239
# end
206240
# end
207241

208-
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractOnesVector, b::AbstractRange)
242+
function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractOnes, b::AbstractRange)
209243
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
210244
TT = typeof(zero(eltype(a)) * zero(eltype(b)))
211245
return _range_convert(AbstractVector{TT}, b)
212246
end
213247

214-
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractOnesVector)
248+
function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractOnes)
215249
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
216250
TT = typeof(zero(eltype(a)) * zero(eltype(b)))
217251
return _range_convert(AbstractVector{TT}, a)
218252
end
219253

220254
for op in (:+, :-)
221255
@eval begin
222-
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector, b::AbstractZerosVector)
223-
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
256+
function broadcasted(::typeof($op), a::AbstractVector, b::AbstractZerosVector)
257+
ax = broadcast_shape(axes(a), axes(b))
258+
ax == axes(a) || throw(ArgumentError("cannot broadcast an array with size $(size(a)) with $b"))
224259
TT = typeof($op(zero(eltype(a)), zero(eltype(b))))
225260
# Use `TT ∘ (+)` to fix AD issues with `broadcasted(TT, x)`
226261
eltype(a) === TT ? a : broadcasted(TT (+), a)
227262
end
228-
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZerosVector, b::AbstractVector)
229-
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first."))
263+
function broadcasted(::typeof($op), a::AbstractZerosVector, b::AbstractVector)
264+
ax = broadcast_shape(axes(a), axes(b))
265+
ax == axes(b) || throw(ArgumentError("cannot broadcast $a with an array with size $(size(b))"))
230266
TT = typeof($op(zero(eltype(a)), zero(eltype(b))))
231267
$op === (+) && eltype(b) === TT ? b : broadcasted(TT ($op), b)
232268
end
233-
234-
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFillVector, b::AbstractZerosVector) =
235-
Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b)
236-
237-
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractZerosVector, b::AbstractFillVector) =
238-
Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b)
269+
function broadcasted(::typeof($op), a::AbstractZerosVector, b::AbstractZerosVector)
270+
ax = broadcast_shape(axes(a), axes(b))
271+
TT = typeof($op(zero(eltype(a)), zero(eltype(b))))
272+
Zeros(TT, ax)
273+
end
239274
end
240275
end
241276

242277
# Need to prevent array-valued fills from broadcasting over entry
243-
_broadcast_getindex_value(a::AbstractFill{<:Number}) = getindex_value(a)
244-
_broadcast_getindex_value(a::AbstractFill) = Ref(getindex_value(a))
245-
278+
_mayberef(x) = Ref(x)
279+
_mayberef(x::Number) = x
246280

247-
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange)
281+
function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange)
248282
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
249-
return broadcasted(*, _broadcast_getindex_value(a), b)
283+
return broadcasted(*, _mayberef(getindex_value(a)), b)
250284
end
251285

252-
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill)
286+
function broadcasted(::FillStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill)
253287
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
254-
return broadcasted(*, a, _broadcast_getindex_value(b))
288+
return broadcasted(*, a, _mayberef(getindex_value(b)))
255289
end
256290

257-
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x), axes(r))
258-
broadcasted(::DefaultArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x, getindex_value(r)), axes(r))
259-
broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcasted_fill(op, r, op(getindex_value(r),x[]), axes(r))
260-
broadcasted(::DefaultArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcasted_fill(op, r, op(x[], getindex_value(r)), axes(r))
261-
262291
# support AbstractFill .^ k
263-
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractFill{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r))
264-
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractOnes{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r))
265-
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractZeros{T,N}, ::Base.RefValue{Val{0}}) where {T,N} = broadcasted_ones(op, r, T, axes(r))
266-
broadcasted(::DefaultArrayStyle{N}, op::typeof(Base.literal_pow), ::Base.RefValue{typeof(^)}, r::AbstractZeros{T,N}, ::Base.RefValue{Val{k}}) where {T,N,k} = broadcasted_zeros(op, r, T, axes(r))
292+
broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractFill{T,N}, ::Val{k}) where {T,N,k} = broadcasted_fill(op, r, getindex_value(r)^k, axes(r))
293+
broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractOnes{T,N}, ::Val{k}) where {T,N,k} = broadcasted_ones(op, r, T, axes(r))
294+
broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractZeros{T,N}, ::Val{0}) where {T,N} = broadcasted_ones(op, r, T, axes(r))
295+
broadcasted(op::typeof(Base.literal_pow), ::typeof(^), r::AbstractZeros{T,N}, ::Val{k}) where {T,N,k} = broadcasted_zeros(op, r, T, axes(r))
267296

268297
# supports structured broadcast
269298
if isdefined(LinearAlgebra, :fzero)

0 commit comments

Comments
 (0)