Skip to content

Commit 6924f86

Browse files
authored
Add support for block fill arrays via BlockArrays.jl (#99)
* Add support for block fill arrays via BlockArrays.jl * Support Block-Triangular-Fill printing
1 parent 552faa7 commit 6924f86

File tree

4 files changed

+87
-27
lines changed

4 files changed

+87
-27
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "0.8.9"
3+
version = "0.8.10"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/FillArrays.jl

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
77
copy, vec, setindex!, count, ==, reshape, _throw_dmrs, map, zero
88

99
import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
10-
norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero
10+
norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AbstractTriangular
1111

1212
import Base.Broadcast: broadcasted, DefaultArrayStyle, broadcast_shape
1313

@@ -272,6 +272,11 @@ end
272272
@inline RectDiagonal{T}(A::V, args...) where {T,V} = RectDiagonal{T,V}(A, args...)
273273
@inline RectDiagonal(A::V, args...) where {V} = RectDiagonal{eltype(V),V}(A, args...)
274274

275+
276+
# patch missing overload from Base
277+
axes(rd::Diagonal{<:Any,<:AbstractFill}) = (axes(rd.diag,1),axes(rd.diag,1))
278+
axes(T::AbstractTriangular{<:Any,<:AbstractFill}) = axes(parent(T))
279+
275280
axes(rd::RectDiagonal) = rd.axes
276281
size(rd::RectDiagonal) = length.(rd.axes)
277282

@@ -302,15 +307,23 @@ for f in (:triu, :triu!, :tril, :tril!)
302307
end
303308

304309

310+
Base.replace_in_print_matrix(A::RectDiagonal, i::Integer, j::Integer, s::AbstractString) =
311+
i == j ? s : Base.replace_with_centered_mark(s)
312+
313+
305314
const RectOrDiagonal{T,V,Axes} = Union{RectDiagonal{T,V,Axes}, Diagonal{T,V}}
306315
const SquareEye{T,Axes} = Diagonal{T,Ones{T,1,Tuple{Axes}}}
307316
const Eye{T,Axes} = RectOrDiagonal{T,Ones{T,1,Tuple{Axes}}}
308317

309318
@inline SquareEye{T}(n::Integer) where T = Diagonal(Ones{T}(n))
310319
@inline SquareEye(n::Integer) = Diagonal(Ones(n))
320+
@inline SquareEye{T}(ax::Tuple{AbstractUnitRange{Int}}) where T = Diagonal(Ones{T}(ax))
321+
@inline SquareEye(ax::Tuple{AbstractUnitRange{Int}}) = Diagonal(Ones(ax))
311322

312-
@inline Eye{T}(n::Integer) where T = Diagonal(Ones{T}(n))
313-
@inline Eye(n::Integer) = Diagonal(Ones(n))
323+
@inline Eye{T}(n::Integer) where T = SquareEye{T}(n)
324+
@inline Eye(n::Integer) = SquareEye(n)
325+
@inline Eye{T}(ax::Tuple{AbstractUnitRange{Int}}) where T = SquareEye{T}(ax)
326+
@inline Eye(ax::Tuple{AbstractUnitRange{Int}}) = SquareEye(ax)
314327

315328
# function iterate(iter::Eye, istate = (1, 1))
316329
# (i::Int, j::Int) = istate
@@ -328,9 +341,21 @@ end
328341

329342
Eye(n::Integer, m::Integer) = RectDiagonal(Ones(min(n,m)), n, m)
330343
Eye{T}(n::Integer, m::Integer) where T = RectDiagonal{T}(Ones{T}(min(n,m)), n, m)
344+
function Eye{T}((a,b)::NTuple{2,AbstractUnitRange{Int}}) where T
345+
ab = length(a)  length(b) ? a : b
346+
RectDiagonal{T}(Ones{T}((ab,)), (a,b))
347+
end
348+
function Eye((a,b)::NTuple{2,AbstractUnitRange{Int}})
349+
ab = length(a)  length(b) ? a : b
350+
RectDiagonal(Ones((ab,)), (a,b))
351+
end
352+
353+
331354
@deprecate Eye{T}(sz::Tuple{Vararg{Integer,2}}) where T Eye{T}(sz...)
332355
@deprecate Eye(sz::Tuple{Vararg{Integer,2}}) Eye{Float64}(sz...)
333356

357+
358+
334359
@inline Eye{T}(A::AbstractMatrix) where T = Eye{T}(size(A)...)
335360
@inline Eye(A::AbstractMatrix) = Eye{eltype(A)}(size(A)...)
336361

@@ -506,5 +531,20 @@ include("fillbroadcast.jl")
506531
Base.replace_in_print_matrix(::Zeros, ::Integer, ::Integer, s::AbstractString) =
507532
Base.replace_with_centered_mark(s)
508533

534+
# following support blocked fill array printing via
535+
# BlockArrays.jl
536+
axes_print_matrix_row(_, io, X, A, i, cols, sep) =
537+
Base.invoke(Base.print_matrix_row, Tuple{IO,AbstractVecOrMat,Vector,Integer,AbstractVector,AbstractString},
538+
io, X, A, i, cols, sep)
539+
540+
Base.print_matrix_row(io::IO,
541+
X::Union{AbstractFill{<:Any,1},
542+
AbstractFill{<:Any,2},
543+
Diagonal{<:Any,<:AbstractFill{<:Any,1}},
544+
RectDiagonal,
545+
AbstractTriangular{<:Any,<:AbstractFill{<:Any,2}}
546+
}, A::Vector,
547+
i::Integer, cols::AbstractVector, sep::AbstractString) =
548+
axes_print_matrix_row(axes(X), io, X, A, i, cols, sep)
509549

510550
end # module

src/fillbroadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ map(f::Function, r::AbstractFill) = Fill(f(getindex_value(r)), axes(r))
66
### Unary broadcasting
77

88
function broadcasted(::DefaultArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N}
9-
return Fill(op(getindex_value(r)), size(r))
9+
return Fill(op(getindex_value(r)), axes(r))
1010
end
1111

1212
broadcasted(::DefaultArrayStyle{N}, ::typeof(conj), r::Zeros{T,N}) where {T,N} = r

test/runtests.jl

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ import FillArrays: AbstractFill, RectDiagonal, SquareEye
9797
@test eltype(Eye(5)) == Float64
9898
@test eltype(Eye(5,6)) == Float64
9999

100+
@test Eye((Base.OneTo(5),)) SquareEye((Base.OneTo(5),)) Eye(5)
101+
@test Eye((Base.OneTo(5),Base.OneTo(6))) Eye(5,6)
102+
100103
for T in (Int, Float64)
101104
E = Eye{T}(5)
102105
M = Matrix{T}(I, 5, 5)
@@ -113,6 +116,9 @@ import FillArrays: AbstractFill, RectDiagonal, SquareEye
113116

114117
@test AbstractArray{Float32}(E) == Eye{Float32}(5)
115118
@test AbstractArray{Float32}(E) == Eye{Float32}(5, 5)
119+
120+
@test Eye{T}(randn(4,5)) Eye{T}(4,5) Eye{T}((Base.OneTo(4),Base.OneTo(5)))
121+
@test Eye{T}((Base.OneTo(5),)) SquareEye{T}((Base.OneTo(5),)) Eye{T}(5)
116122
end
117123

118124
@testset "Bool should change type" begin
@@ -169,6 +175,7 @@ end
169175
@test expected[2, :] == expected_matrix[2, :]
170176
@test expected[5, :] == expected_matrix[5, :]
171177

178+
172179
for Typ in (RectDiagonal, RectDiagonal{Int}, RectDiagonal{Int, UnitRange{Int}})
173180
@test Typ(data) == expected[1:3, 1:3]
174181
@test Typ(data, expected_axes) == expected
@@ -187,6 +194,9 @@ end
187194
@test diag(mut) == [5, 2, 3]
188195
mut[2, 1] = 0
189196
@test_throws ArgumentError mut[2, 1] = 9
197+
198+
D = RectDiagonal([1.,2.], (Base.OneTo(3),Base.OneTo(2)))
199+
@test stringmime("text/plain", D) == "3×2 RectDiagonal{Float64,Array{Float64,1},Tuple{Base.OneTo{$Int},Base.OneTo{$Int}}}:\n 1.0 ⋅ \n ⋅ 2.0\n ⋅ ⋅ "
190200
end
191201

192202
# Check that all pair-wise combinations of + / - elements of As and Bs yield the correct
@@ -538,6 +548,9 @@ end
538548
@test Zeros(5) ./ Fill(5.0, 5) Zeros(5) ./ 5.0 Zeros(5)
539549
@test Ones(5) .\ Zeros(5) 1 .\ Zeros(5) Zeros(5)
540550
@test Fill(5.0, 5) .\ Zeros(5) 5.0 .\ Zeros(5) Zeros(5)
551+
552+
@test conj.(Zeros(5)) Zeros(5)
553+
@test conj.(Zeros{ComplexF64}(5)) Zeros{ComplexF64}(5)
541554
end
542555

543556
@testset "support Ref" begin
@@ -564,6 +577,15 @@ end
564577
@test Ones(10) - Zeros(10) Ones(10)
565578
@test Fill(1,10) - Zeros(10) Fill(1.0,10)
566579
end
580+
581+
@testset "Zero .*" begin
582+
@test Zeros{Int}(10) .* Zeros{Int}(10) Zeros{Int}(10)
583+
@test randn(10) .* Zeros(10) Zeros(10)
584+
@test Zeros(10) .* randn(10) Zeros(10)
585+
@test (1:10) .* Zeros(10) Zeros(10)
586+
@test Zeros(10) .* (1:10) Zeros(10)
587+
@test_throws DimensionMismatch (1:11) .* Zeros(10)
588+
end
567589
end
568590

569591
@testset "map" begin
@@ -649,15 +671,6 @@ end
649671
@test allunique(Ones(0))
650672
end
651673

652-
@testset "Zero .*" begin
653-
@test Zeros{Int}(10) .* Zeros{Int}(10) Zeros{Int}(10)
654-
@test randn(10) .* Zeros(10) Zeros(10)
655-
@test Zeros(10) .* randn(10) Zeros(10)
656-
@test (1:10) .* Zeros(10) Zeros(10)
657-
@test Zeros(10) .* (1:10) Zeros(10)
658-
@test_throws DimensionMismatch (1:11) .* Zeros(10)
659-
end
660-
661674
@testset "iterate" begin
662675
for d in (0, 1, 2, 100)
663676
for T in (Float64, Int)
@@ -967,17 +980,24 @@ end
967980
@test_throws ArgumentError rmul!(x,2.0)
968981
end
969982

970-
@testset "Diagonal{<:Fill}" begin
971-
D = Diagonal(Fill(Fill(0.5,2,2),10))
972-
@test @inferred(D[1,1]) === Fill(0.5,2,2)
973-
@test @inferred(D[1,2]) === Fill(0.0,2,2)
974-
D = Diagonal(Fill(Zeros(2,2),10))
975-
@test @inferred(D[1,1]) === Zeros(2,2)
976-
@test @inferred(D[1,2]) === Zeros(2,2)
977-
978-
D = Diagonal([Zeros(1,1), Zeros(2,2)])
979-
@test @inferred(D[1,1]) === Zeros(1,1)
980-
@test @inferred(D[1,2]) === Zeros(1,2)
981-
982-
@test_throws ArgumentError Diagonal(Fill(Ones(2,2),10))[1,2]
983+
@testset "Modified" begin
984+
@testset "Diagonal{<:Fill}" begin
985+
D = Diagonal(Fill(Fill(0.5,2,2),10))
986+
@test @inferred(D[1,1]) === Fill(0.5,2,2)
987+
@test @inferred(D[1,2]) === Fill(0.0,2,2)
988+
@test axes(D) == (Base.OneTo(10),Base.OneTo(10))
989+
D = Diagonal(Fill(Zeros(2,2),10))
990+
@test @inferred(D[1,1]) === Zeros(2,2)
991+
@test @inferred(D[1,2]) === Zeros(2,2)
992+
D = Diagonal([Zeros(1,1), Zeros(2,2)])
993+
@test @inferred(D[1,1]) === Zeros(1,1)
994+
@test @inferred(D[1,2]) === Zeros(1,2)
995+
996+
@test_throws ArgumentError Diagonal(Fill(Ones(2,2),10))[1,2]
997+
end
998+
@testset "Triangular" begin
999+
U = UpperTriangular(Ones(3,3))
1000+
@test U == UpperTriangular(ones(3,3))
1001+
@test axes(U) == (Base.OneTo(3),Base.OneTo(3))
1002+
end
9831003
end

0 commit comments

Comments
 (0)