Skip to content

Commit d88461b

Browse files
authored
Add special broadcast routines for + (#122)
* Add special broadcast for + * restrict to Vector * ambiguities * use fillsimilar for more general code, generalise getindex * Update runtests.jl
1 parent 1179e7d commit d88461b

File tree

5 files changed

+137
-74
lines changed

5 files changed

+137
-74
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.9.7"
3+
version = "0.10"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FillArrays.jl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,18 @@ convert(::Type{T}, F::T) where T<:Fill = F # ambiguity fix
158158

159159
getindex(F::Fill{<:Any,0}) = getindex_value(F)
160160

161-
Base._unsafe_getindex(::IndexStyle, F::Fill, kj::Vararg{AbstractVector{II},N}) where {II<:Integer,N} =
162-
Fill(getindex_value(F), length.(kj))
161+
function Base._unsafe_getindex(::IndexStyle, F::AbstractFill, I::Vararg{Union{Real, AbstractArray}, N}) where N
162+
shape = Base.index_shape(I...)
163+
fillsimilar(F, shape)
164+
end
163165

164-
function getindex(A::Fill, kr::AbstractVector{Bool})
166+
function getindex(A::AbstractFill, kr::AbstractVector{Bool})
165167
length(A) == length(kr) || throw(DimensionMismatch())
166-
Fill(getindex_value(A), count(kr))
168+
fillsimilar(A, count(kr))
167169
end
168-
function getindex(A::Fill, kr::AbstractArray{Bool})
170+
function getindex(A::AbstractFill, kr::AbstractArray{Bool})
169171
size(A) == size(kr) || throw(DimensionMismatch())
170-
Fill(getindex_value(A), count(kr))
172+
fillsimilar(A, count(kr))
171173
end
172174

173175
sort(a::AbstractFill; kwds...) = a
@@ -202,7 +204,7 @@ end
202204
function fill_reshape(parent, dims::Integer...)
203205
n = length(parent)
204206
prod(dims) == n || _throw_dmrs(n, "size", dims)
205-
Fill(getindex_value(parent), dims...)
207+
fillsimilar(parent, dims...)
206208
end
207209

208210
reshape(parent::AbstractFill, dims::Integer...) = reshape(parent, dims)
@@ -267,26 +269,21 @@ for (Typ, funcs, func) in ((:Zeros, :zeros, :zero), (:Ones, :ones, :one))
267269
copy(F::$Typ) = F
268270

269271
getindex(F::$Typ{T,0}) where T = getindex_value(F)
270-
Base._unsafe_getindex(::IndexStyle, F::$Typ{T}, kj::Vararg{AbstractVector{II},N}) where {T,II<:Integer,N} =
271-
$Typ{T}(length.(kj))
272-
function getindex(A::$Typ{T}, kr::AbstractVector{Bool}) where T
273-
length(A) == length(kr) || throw(DimensionMismatch("lengths must match"))
274-
$Typ{T}(count(kr))
275-
end
276-
function getindex(A::$Typ{T}, kr::AbstractArray{Bool}) where T
277-
size(A) == size(kr) || throw(DimensionMismatch("sizes must match"))
278-
$Typ{T}(count(kr))
279-
end
280-
281-
function fill_reshape(parent::$Typ{T}, dims::Integer...) where T
282-
n = length(parent)
283-
prod(dims) == n || _throw_dmrs(n, "size", dims)
284-
$Typ{T}(dims...)
285-
end
286272
end
287273
end
288274

289275

276+
"""
277+
fillsimilar(a::AbstractFill, axes)
278+
279+
creates a fill object that has the same fill value as `a` but
280+
with the specified axes.
281+
For example, if `a isa Zeros` then so is the returned object.
282+
"""
283+
fillsimilar(a::Ones{T}, axes...) where T = Ones{T}(axes...)
284+
fillsimilar(a::Zeros{T}, axes...) where T = Zeros{T}(axes...)
285+
fillsimilar(a::AbstractFill, axes...) = Fill(getindex_value(a), axes...)
286+
290287

291288
rank(F::Zeros) = 0
292289
rank(F::Ones) = 1

src/fillalgebra.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@ adjoint(a::Zeros{T,2}) where T = Zeros{T}(reverse(a.axes))
1515
transpose(a::Fill{T,2}) where T = Fill{T}(transpose(a.value), reverse(a.axes))
1616
adjoint(a::Fill{T,2}) where T = Fill{T}(adjoint(a.value), reverse(a.axes))
1717

18-
fillsimilar(a::Ones{T}, axes) where T = Ones{T}(axes)
19-
fillsimilar(a::Zeros{T}, axes) where T = Zeros{T}(axes)
20-
fillsimilar(a::AbstractFill, axes) = Fill(getindex_value(a), axes)
21-
2218
permutedims(a::AbstractFill{<:Any,1}) = fillsimilar(a, (1, length(a)))
2319
permutedims(a::AbstractFill{<:Any,2}) = fillsimilar(a, reverse(a.axes))
2420

src/fillbroadcast.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ broadcasted(::DefaultArrayStyle, ::typeof(-), a::Zeros, b::Zeros) = _broadcasted
4444
broadcasted(::DefaultArrayStyle, ::typeof(-), a::Ones, b::Zeros) = _broadcasted_ones(-, a, b)
4545
broadcasted(::DefaultArrayStyle, ::typeof(-), a::Ones, b::Ones) = _broadcasted_zeros(-, a, b)
4646

47+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::Zeros{<:Any,1}, b::Zeros{<:Any,1}) = _broadcasted_zeros(+, a, b)
48+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::Ones{<:Any,1}, b::Zeros{<:Any,1}) = _broadcasted_ones(+, a, b)
49+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::Zeros{<:Any,1}, b::Ones{<:Any,1}) = _broadcasted_ones(+, a, b)
50+
51+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::Zeros{<:Any,1}, b::Zeros{<:Any,1}) = _broadcasted_zeros(-, a, b)
52+
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::Ones{<:Any,1}, b::Zeros{<:Any,1}) = _broadcasted_ones(-, a, b)
53+
54+
4755
broadcasted(::DefaultArrayStyle, ::typeof(*), a::Zeros, b::Zeros) = _broadcasted_zeros(*, a, b)
4856

4957
# In following, need to restrict to <: Number as otherwise we cannot infer zero from type
@@ -109,16 +117,36 @@ _range_convert(::Type{AbstractVector{T}}, a::AbstractRange) where T = convert(T,
109117
# end
110118
# end
111119

112-
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::Ones{T}, b::AbstractRange{V}) where {T,V}
120+
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::Ones{T,1}, b::AbstractRange{V}) where {T,V}
113121
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
114122
return _range_convert(AbstractVector{promote_type(T,V)}, b)
115123
end
116124

117-
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange{V}, b::Ones{T}) where {T,V}
125+
function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange{V}, b::Ones{T,1}) where {T,V}
118126
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
119127
return _range_convert(AbstractVector{promote_type(T,V)}, a)
120128
end
121129

130+
for op in (:+, -)
131+
@eval begin
132+
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector{T}, b::Zeros{V,1}) where {T,V}
133+
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
134+
LinearAlgebra.copy_oftype(a, promote_type(T,V))
135+
end
136+
137+
broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFill{T,1}, b::Zeros{V,1}) where {T,V} =
138+
Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof($op), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), $op, a, b)
139+
end
140+
end
141+
142+
function broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::Zeros{T,1}, b::AbstractVector{V}) where {T,V}
143+
broadcast_shape(axes(a), axes(b))
144+
LinearAlgebra.copy_oftype(b, promote_type(T,V))
145+
end
146+
147+
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::Zeros{V,1}, b::AbstractFill{T,1}) where {T,V} =
148+
Base.invoke(broadcasted, Tuple{DefaultArrayStyle, typeof(+), AbstractFill, AbstractFill}, DefaultArrayStyle{1}(), +, a, b)
149+
122150
# Need to prevent array-valued fills from broadcasting over entry
123151
_broadcast_getindex_value(a::AbstractFill{<:Number}) = getindex_value(a)
124152
_broadcast_getindex_value(a::AbstractFill) = Ref(getindex_value(a))

test/runtests.jl

Lines changed: 86 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,70 @@ import FillArrays: AbstractFill, RectDiagonal, SquareEye
165165
end
166166
end
167167

168+
@testset "indexing" begin
169+
A = Fill(3.0,5)
170+
@test A[1:3] Fill(3.0,3)
171+
@test A[1:3,1:1] Fill(3.0,3,1)
172+
@test_throws BoundsError A[1:3,2]
173+
@test_throws BoundsError A[1:26]
174+
@test A[[true, false, true, false, false]] Fill(3.0, 2)
175+
A = Fill(3.0, 2, 2)
176+
@test A[[true true; true false]] Fill(3.0, 3)
177+
@test_throws DimensionMismatch A[[true, false]]
178+
179+
A = Ones{Int}(5,5)
180+
@test A[1:3] Ones{Int}(3)
181+
@test A[1:3,1:2] Ones{Int}(3,2)
182+
@test A[1:3,2] Ones{Int}(3)
183+
@test_throws BoundsError A[1:26]
184+
A = Ones{Int}(2,2)
185+
@test A[[true false; true false]] Ones{Int}(2)
186+
@test A[[true, false, true, false]] Ones{Int}(2)
187+
@test_throws DimensionMismatch A[[true false false; true false false]]
188+
189+
A = Zeros{Int}(5,5)
190+
@test A[1:3] Zeros{Int}(3)
191+
@test A[1:3,1:2] Zeros{Int}(3,2)
192+
@test A[1:3,2] Zeros{Int}(3)
193+
@test_throws BoundsError A[1:26]
194+
A = Zeros{Int}(2,2)
195+
@test A[[true false; true false]] Zeros{Int}(2)
196+
@test A[[true, false, true, false]] Zeros{Int}(2)
197+
@test_throws DimensionMismatch A[[true false false; true false false]]
198+
199+
@testset "colon" begin
200+
@test Ones(2)[:] Ones(2)[Base.Slice(Base.OneTo(2))] Ones(2)
201+
@test Zeros(2)[:] Zeros(2)[Base.Slice(Base.OneTo(2))] Zeros(2)
202+
@test Fill(3.0,2)[:] Fill(3.0,2)[Base.Slice(Base.OneTo(2))] Fill(3.0,2)
203+
204+
@test Ones(2,2)[:,:] Ones(2,2)[Base.Slice(Base.OneTo(2)),Base.Slice(Base.OneTo(2))] Ones(2,2)
205+
@test Zeros(2,2)[:,:] Zeros(2)[Base.Slice(Base.OneTo(2)),Base.Slice(Base.OneTo(2))] Zeros(2,2)
206+
@test Fill(3.0,2,2)[:,:] Fill(3.0,2,2)[Base.Slice(Base.OneTo(2)),Base.Slice(Base.OneTo(2))] Fill(3.0,2,2)
207+
end
208+
209+
@testset "mixed integer / vector /colon" begin
210+
a = Fill(2.0,5)
211+
z = Zeros(5)
212+
@test a[1:5] a[:] a
213+
@test z[1:5] z[:] z
214+
215+
A = Fill(2.0,5,6)
216+
Z = Zeros(5,6)
217+
@test A[:,1] A[1:5,1] Fill(2.0,5)
218+
@test A[1,:] A[1,1:6] Fill(2.0,6)
219+
@test A[:,:] A[1:5,1:6] A[1:5,:] A[:,1:6] A
220+
@test Z[:,1] Z[1:5,1] Zeros(5)
221+
@test Z[1,:] Z[1,1:6] Zeros(6)
222+
@test Z[:,:] Z[1:5,1:6] Z[1:5,:] Z[:,1:6] Z
223+
224+
A = Fill(2.0,5,6,7)
225+
Z = Zeros(5,6,7)
226+
@test A[:,1,1] A[1:5,1,1] Fill(2.0,5)
227+
@test A[1,:,1] A[1,1:6,1] Fill(2.0,6)
228+
@test A[:,:,:] A[1:5,1:6,1:7] A[1:5,:,1:7] A[:,1:6,1:7] A
229+
end
230+
end
231+
168232
@testset "RectDiagonal" begin
169233
data = 1:3
170234
expected_size = (5, 3)
@@ -542,8 +606,8 @@ end
542606
@testset "range broadcast" begin
543607
rnge = range(-5.0, step=1.0, length=10)
544608
@test broadcast(*, Fill(5.0, 10), rnge) == broadcast(*, 5.0, rnge)
545-
@test broadcast(*, Zeros(10, 10), rnge) == zeros(10, 10)
546-
@test broadcast(*, rnge, Zeros(10, 10)) == zeros(10, 10)
609+
@test broadcast(*, Zeros(10, 10), rnge) Zeros{Float64}(10, 10)
610+
@test broadcast(*, rnge, Zeros(10, 10)) Zeros{Float64}(10, 10)
547611
@test broadcast(*, Ones{Int}(10), rnge) rnge
548612
@test broadcast(*, rnge, Ones{Int}(10)) rnge
549613
@test_throws DimensionMismatch broadcast(*, Fill(5.0, 11), rnge)
@@ -554,6 +618,12 @@ end
554618
deg = 5:5
555619
@test_throws ArgumentError @inferred(broadcast(*, Fill(5.0, 10), deg)) == broadcast(*, fill(5.0,10), deg)
556620
@test_throws ArgumentError @inferred(broadcast(*, deg, Fill(5.0, 10))) == broadcast(*, deg, fill(5.0,10))
621+
622+
@test rnge .+ Zeros(10) rnge .- Zeros(10) Zeros(10) .+ rnge rnge
623+
624+
@test_throws DimensionMismatch rnge .+ Zeros(5)
625+
@test_throws DimensionMismatch rnge .- Zeros(5)
626+
@test_throws DimensionMismatch Zeros(5) .+ rnge
557627
end
558628

559629
@testset "Special Zeros/Ones" begin
@@ -598,6 +668,15 @@ end
598668
@test Zeros(5) ./ Zeros(5) Zeros(5) .\ Zeros(5) Fill(NaN,5)
599669
@test Zeros{Int}(5,6) ./ Zeros{Int}(5) Zeros{Int}(5) .\ Zeros{Int}(5,6) Fill(NaN,5,6)
600670
end
671+
672+
@testset "Addition" begin
673+
@test Zeros{Int}(5) .+ (1:5) (1:5) .+ Zeros{Int}(5) (1:5) .- Zeros{Int}(5) 1:5
674+
@test Zeros{Int}(1) .+ (1:5) (1:5) .+ Zeros{Int}(1) (1:5) .- Zeros{Int}(1) 1:5
675+
@test Zeros(5) .+ (1:5) == (1:5) .+ Zeros(5) == (1:5) .- Zeros(5) == 1:5
676+
@test Zeros{Int}(5) .+ Fill(1,5) Fill(1,5) .+ Zeros{Int}(5) Fill(1,5) .- Zeros{Int}(5) Fill(1,5)
677+
@test_throws DimensionMismatch Zeros{Int}(2) .+ (1:5)
678+
@test_throws DimensionMismatch (1:5) .+ Zeros{Int}(2)
679+
end
601680
end
602681

603682
@testset "support Ref" begin
@@ -625,6 +704,11 @@ end
625704
@test Ones(10) - Ones(10) Zeros(10)
626705
@test Fill(1,10) - Zeros(10) Fill(1.0,10)
627706

707+
@test Zeros(10) .- Zeros(10) Zeros(10)
708+
@test Ones(10) .- Zeros(10) Ones(10)
709+
@test Ones(10) .- Ones(10) Zeros(10)
710+
@test Fill(1,10) .- Zeros(10) Fill(1.0,10)
711+
628712
@test Zeros(10) .- Zeros(1,9) Zeros(10,9)
629713
@test Ones(10) .- Zeros(1,9) Ones(10,9)
630714
@test Ones(10) .- Ones(1,9) Zeros(10,9)
@@ -652,48 +736,6 @@ end
652736
@test map(exp,x) === Fill(exp(2),5,3)
653737
end
654738

655-
@testset "Sub-arrays" begin
656-
A = Fill(3.0,5)
657-
@test A[1:3] Fill(3.0,3)
658-
@test A[1:3,1:1] Fill(3.0,3,1)
659-
@test_broken A[1:3,2] Zeros{Int}(3)
660-
@test_throws BoundsError A[1:26]
661-
@test A[[true, false, true, false, false]] Fill(3.0, 2)
662-
A = Fill(3.0, 2, 2)
663-
@test A[[true true; true false]] Fill(3.0, 3)
664-
@test_throws DimensionMismatch A[[true, false]]
665-
666-
A = Ones{Int}(5,5)
667-
@test A[1:3] Ones{Int}(3)
668-
@test A[1:3,1:2] Ones{Int}(3,2)
669-
@test_broken A[1:3,2] Ones{Int}(3)
670-
@test_throws BoundsError A[1:26]
671-
A = Ones{Int}(2,2)
672-
@test A[[true false; true false]] Ones{Int}(2)
673-
@test A[[true, false, true, false]] Ones{Int}(2)
674-
@test_throws DimensionMismatch A[[true false false; true false false]]
675-
676-
A = Zeros{Int}(5,5)
677-
@test A[1:3] Zeros{Int}(3)
678-
@test A[1:3,1:2] Zeros{Int}(3,2)
679-
@test_broken A[1:3,2] Zeros{Int}(3)
680-
@test_throws BoundsError A[1:26]
681-
A = Zeros{Int}(2,2)
682-
@test A[[true false; true false]] Zeros{Int}(2)
683-
@test A[[true, false, true, false]] Zeros{Int}(2)
684-
@test_throws DimensionMismatch A[[true false false; true false false]]
685-
686-
@testset "colon" begin
687-
@test Ones(2)[:] Ones(2)[Base.Slice(Base.OneTo(2))] Ones(2)
688-
@test Zeros(2)[:] Zeros(2)[Base.Slice(Base.OneTo(2))] Zeros(2)
689-
@test Fill(3.0,2)[:] Fill(3.0,2)[Base.Slice(Base.OneTo(2))] Fill(3.0,2)
690-
691-
@test Ones(2,2)[:,:] Ones(2,2)[Base.Slice(Base.OneTo(2)),Base.Slice(Base.OneTo(2))] Ones(2,2)
692-
@test Zeros(2,2)[:,:] Zeros(2)[Base.Slice(Base.OneTo(2)),Base.Slice(Base.OneTo(2))] Zeros(2,2)
693-
@test Fill(3.0,2,2)[:,:] Fill(3.0,2,2)[Base.Slice(Base.OneTo(2)),Base.Slice(Base.OneTo(2))] Fill(3.0,2,2)
694-
end
695-
end
696-
697739
@testset "Offset indexing" begin
698740
A = Fill(3, (Base.Slice(-1:1),))
699741
@test axes(A) == (Base.Slice(-1:1),)

0 commit comments

Comments
 (0)