Skip to content

Commit fe34c63

Browse files
authored
Allow more general *TridiagonalLayout, improve interface (#56)
* Allow more general *TridiagonalLayout, improve interface * tests pass again * Hopefully increase coverage * restore sub/sup/diagonaldata * special mul * Require Julia v1.5 * Drop < v"1.5", fix empty range col/rowsupport for bidiagonal * Update diagonal.jl
1 parent e28021e commit fe34c63

File tree

9 files changed

+89
-54
lines changed

9 files changed

+89
-54
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ jobs:
1010
fail-fast: false
1111
matrix:
1212
version:
13-
- '1.0'
14-
- '1'
13+
- '1.5'
1514
- '^1.6.0-0'
1615
os:
1716
- ubuntu-latest

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.5.4"
4+
version = "0.6.0"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -12,7 +12,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1212
[compat]
1313
Compat = "3.16"
1414
FillArrays = "0.11"
15-
julia = "1"
15+
julia = "1.5"
1616

1717
[extras]
1818
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

src/ArrayLayouts.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,7 @@ import LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex
4141

4242
import FillArrays: AbstractFill, getindex_value, axes_print_matrix_row
4343

44-
if VERSION < v"1.2-"
45-
import Base: has_offset_axes
46-
require_one_based_indexing(A...) = !has_offset_axes(A...) || throw(ArgumentError("offset arrays are not supported but got an array with index other than 1"))
47-
else
48-
import Base: require_one_based_indexing
49-
end
44+
import Base: require_one_based_indexing
5045

5146
export materialize, materialize!, MulAdd, muladd!, Ldiv, Rdiv, Lmul, Rmul, Dot,
5247
lmul, rmul, mul, ldiv, rdiv, mul, MemoryLayout, AbstractStridedLayout,
@@ -208,10 +203,8 @@ copyto!(dest::SubArray{<:Any,2,<:LayoutArray}, src::AdjOrTrans{<:Any,<:LayoutArr
208203
copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SubArray{<:Any,2,<:AdjOrTrans{<:Any,<:LayoutArray}}) = _copyto!(dest, src)
209204
copyto!(dest::AbstractMatrix, src::SubArray{<:Any,2,<:AdjOrTrans{<:Any,<:LayoutArray}}) = _copyto!(dest, src)
210205
# ambiguity from sparsematrix.jl
211-
if VERSION v"1.5"
212-
copyto!(dest::LayoutMatrix, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
213-
copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
214-
end
206+
copyto!(dest::LayoutMatrix, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
207+
copyto!(dest::SubArray{<:Any,2,<:LayoutMatrix}, src::SparseArrays.AbstractSparseMatrixCSC) = _copyto!(dest, src)
215208

216209
# avoid bad copy in Base
217210
Base.map(::typeof(copy), D::Diagonal{<:LayoutArray}) = Diagonal(map(copy, D.diag))

src/diagonal.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ copy(M::Rmul{<:Any,<:DiagonalLayout}) = M.A .* permutedims(diagonaldata(M.B))
2222

2323

2424

25-
2625
# Diagonal multiplication never changes structure
2726
similar(M::Rmul{<:Any,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.A, T, axes)
2827
# equivalent to rescaling
@@ -45,4 +44,21 @@ copy(M::Ldiv{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = Diagona
4544
copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout}) = Diagonal(M.A.diag .* inv.(M.B.diag))
4645
copy(M::Rdiv{<:Any,<:DiagonalLayout}) = M.A .* inv.(permutedims(M.B.diag))
4746
copy(M::Rdiv{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* inv(getindex_value(M.B.diag))
48-
copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = Diagonal(M.A.diag .* inv(getindex_value(M.B.diag)))
47+
copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = Diagonal(M.A.diag .* inv(getindex_value(M.B.diag)))
48+
49+
50+
## bi/tridiagonal copy
51+
copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout}) = convert(Bidiagonal, M.A) * M.B
52+
copy(M::Lmul{<:DiagonalLayout,<:BidiagonalLayout}) = M.A * convert(Bidiagonal, M.B)
53+
copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout}) = convert(Tridiagonal, M.A) * M.B
54+
copy(M::Lmul{<:DiagonalLayout,<:TridiagonalLayout}) = M.A * convert(Tridiagonal, M.B)
55+
copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout}) = convert(SymTridiagonal, M.A) * M.B
56+
copy(M::Lmul{<:DiagonalLayout,<:SymTridiagonalLayout}) = M.A * convert(SymTridiagonal, M.B)
57+
58+
59+
copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
60+
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:BidiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
61+
copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
62+
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:TridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
63+
copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
64+
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:SymTridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B

src/ldiv.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,7 @@ __ldiv!(_, F, B) = LinearAlgebra.ldiv!(F, B)
9595
@inline materialize!(M::Ldiv) = _ldiv!(M.A, M.B)
9696
@inline materialize!(M::Rdiv) = ldiv!(M.B', M.A')'
9797
@inline copyto!(dest::AbstractArray, M::Rdiv) = copyto!(dest', Ldiv(M.B', M.A'))'
98-
99-
if VERSION v"1.1-pre"
100-
@inline copyto!(dest::AbstractArray, M::Ldiv) = _ldiv!(dest, M.A, M.B)
101-
else
102-
@inline copyto!(dest::AbstractArray, M::Ldiv) = _ldiv!(dest, M.A, copy(M.B))
103-
end
98+
@inline copyto!(dest::AbstractArray, M::Ldiv) = _ldiv!(dest, M.A, copy(M.B))
10499

105100
const MatLdivVec{styleA, styleB, T, V} = Ldiv{styleA, styleB, <:AbstractMatrix{T}, <:AbstractVector{V}}
106101
const MatLdivMat{styleA, styleB, T, V} = Ldiv{styleA, styleB, <:AbstractMatrix{T}, <:AbstractMatrix{V}}

src/memorylayout.jl

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -542,20 +542,24 @@ abstract type AbstractBandedLayout <: MemoryLayout end
542542
abstract type AbstractTridiagonalLayout <: AbstractBandedLayout end
543543

544544
struct DiagonalLayout{ML} <: AbstractBandedLayout end
545-
struct BidiagonalLayout{ML} <: AbstractBandedLayout end
546-
struct SymTridiagonalLayout{ML} <: AbstractTridiagonalLayout end
547-
struct TridiagonalLayout{ML} <: AbstractTridiagonalLayout end
545+
struct BidiagonalLayout{DV,EV} <: AbstractBandedLayout end
546+
struct SymTridiagonalLayout{DV,EV} <: AbstractTridiagonalLayout end
547+
struct TridiagonalLayout{DL,D,DU} <: AbstractTridiagonalLayout end
548548

549-
bidiagonallayout(_) = BidiagonalLayout{UnknownLayout}()
550-
tridiagonallayout(_) = TridiagonalLayout{UnknownLayout}()
551-
symtridiagonallayout(_) = SymTridiagonalLayout{UnknownLayout}()
549+
bidiagonallayout(dv, ev) = BidiagonalLayout{UnknownLayout,UnknownLayout}()
550+
tridiagonallayout(dl, d, du) = TridiagonalLayout{UnknownLayout,UnknownLayout,UnknownLayout}()
551+
552+
symtridiagonallayout(d, ev) = SymTridiagonalLayout{UnknownLayout,UnknownLayout}()
553+
bidiagonallayout(d) = bidiagonallayout(d, d)
554+
tridiagonallayout(d) = tridiagonallayout(d,d,d)
555+
symtridiagonallayout(d) = symtridiagonallayout(d,d)
552556
diagonallayout(_) = DiagonalLayout{UnknownLayout}()
553557

554558

555-
diagonallayout(lay::Union{AbstractStridedLayout, AbstractFillLayout}) = DiagonalLayout{typeof(lay)}()
556-
bidiagonallayout(lay::Union{AbstractStridedLayout, AbstractFillLayout}) = BidiagonalLayout{typeof(lay)}()
557-
tridiagonallayout(lay::Union{AbstractStridedLayout, AbstractFillLayout}) = TridiagonalLayout{typeof(lay)}()
558-
symtridiagonallayout(lay::Union{AbstractStridedLayout, AbstractFillLayout}) = SymTridiagonalLayout{typeof(lay)}()
559+
diagonallayout(::Lay) where Lay<:Union{AbstractStridedLayout, AbstractFillLayout} = DiagonalLayout{Lay}()
560+
bidiagonallayout(::Lay, ::Lay) where Lay<:Union{AbstractStridedLayout, AbstractFillLayout} = BidiagonalLayout{Lay,Lay}()
561+
tridiagonallayout(::Lay, ::Lay, ::Lay) where Lay<:Union{AbstractStridedLayout, AbstractFillLayout} = TridiagonalLayout{Lay,Lay,Lay}()
562+
symtridiagonallayout(::Lay, ::Lay) where Lay<:Union{AbstractStridedLayout, AbstractFillLayout} = SymTridiagonalLayout{Lay,Lay}()
559563

560564

561565
MemoryLayout(D::Type{Diagonal{T,P}}) where {T,P} = diagonallayout(MemoryLayout(P))
@@ -586,9 +590,11 @@ transposelayout(ml::SymTridiagonalLayout) = ml
586590
transposelayout(ml::TridiagonalLayout) = ml
587591
transposelayout(ml::ConjLayout{DiagonalLayout}) = ml
588592

589-
triangularlayout(::Type{<:TriangularLayout{UPLO,'N'}}, ::TridiagonalLayout{ML}) where {UPLO,ML} = BidiagonalLayout{ML}()
590-
triangularlayout(::Type{<:TriangularLayout{UPLO,'N'}}, ::TridiagonalLayout{FillLayout}) where UPLO = BidiagonalLayout{FillLayout}()
591-
triangularlayout(::Type{<:TriangularLayout{UPLO,'U'}}, ::TridiagonalLayout{FillLayout}) where UPLO = BidiagonalLayout{FillLayout}()
593+
triangularlayout(::Type{<:TriangularLayout{'L','N'}}, ::TridiagonalLayout{DL,D,DU}) where {DL,D,DU} = BidiagonalLayout{D,DL}()
594+
triangularlayout(::Type{<:TriangularLayout{'U','N'}}, ::TridiagonalLayout{DL,D,DU}) where {UPLO,DL,D,DU} = BidiagonalLayout{D,DU}()
595+
triangularlayout(::Type{<:TriangularLayout{'L','N'}}, ::TridiagonalLayout{FillLayout,FillLayout,FillLayout}) = BidiagonalLayout{FillLayout,FillLayout}()
596+
triangularlayout(::Type{<:TriangularLayout{'U','N'}}, ::TridiagonalLayout{FillLayout,FillLayout,FillLayout}) = BidiagonalLayout{FillLayout,FillLayout}()
597+
triangularlayout(::Type{<:TriangularLayout{UPLO,'U'}}, ::TridiagonalLayout{FillLayout,FillLayout,FillLayout}) where UPLO = BidiagonalLayout{FillLayout,FillLayout}()
592598

593599
bidiagonaluplo(::Union{UpperTriangular,UnitUpperTriangular}) = 'U'
594600
bidiagonaluplo(::Union{LowerTriangular,UnitLowerTriangular}) = 'L'
@@ -598,20 +604,20 @@ supdiagonaldata(U::Union{UnitUpperTriangular,UpperTriangular}) = supdiagonaldata
598604
subdiagonaldata(U::Union{UnitLowerTriangular,LowerTriangular}) = subdiagonaldata(triangulardata(U))
599605

600606
adjointlayout(::Type{<:Real}, ml::SymTridiagonalLayout) = ml
601-
adjointlayout(::Type{<:Real}, ml::TridiagonalLayout) = ml
607+
adjointlayout(::Type{<:Real}, ::TridiagonalLayout{DL,D,DU}) where {DL,D,DU} = TridiagonalLayout{DU,D,DL}()
602608
adjointlayout(::Type{<:Real}, ml::BidiagonalLayout) = ml
603609

604-
symmetriclayout(B::BidiagonalLayout{ML}) where ML = SymTridiagonalLayout{ML}()
605-
hermitianlayout(::Type{<:Real}, B::BidiagonalLayout{ML}) where ML = SymTridiagonalLayout{ML}()
610+
symmetriclayout(B::BidiagonalLayout{DV,EV}) where {DV,EV} = SymTridiagonalLayout{DV,EV}()
611+
hermitianlayout(::Type{<:Real}, B::BidiagonalLayout{DV,EV}) where {DV,EV} = SymTridiagonalLayout{DV,EV}()
606612
hermitianlayout(_, B::BidiagonalLayout) = HermitianLayout{typeof(B)}()
607613

614+
diagonaldata(D::Transpose) = diagonaldata(parent(D))
608615
subdiagonaldata(D::Transpose) = supdiagonaldata(parent(D))
609616
supdiagonaldata(D::Transpose) = subdiagonaldata(parent(D))
610-
diagonaldata(D::Transpose) = diagonaldata(parent(D))
611617

618+
diagonaldata(D::Adjoint{<:Real}) = diagonaldata(parent(D))
612619
subdiagonaldata(D::Adjoint{<:Real}) = supdiagonaldata(parent(D))
613620
supdiagonaldata(D::Adjoint{<:Real}) = subdiagonaldata(parent(D))
614-
diagonaldata(D::Adjoint{<:Real}) = diagonaldata(parent(D))
615621

616622
diagonaldata(S::HermOrSym) = diagonaldata(parent(S))
617623
subdiagonaldata(S::HermOrSym) = symmetricuplo(S) == 'L' ? subdiagonaldata(parent(S)) : supdiagonaldata(parent(S))
@@ -653,10 +659,14 @@ colsupport(::ZerosLayout, A, _) = 1:0
653659
rowsupport(::DiagonalLayout, _, k) = isempty(k) ? (1:0) : minimum(k):maximum(k)
654660
colsupport(::DiagonalLayout, _, j) = isempty(j) ? (1:0) : minimum(j):maximum(j)
655661

656-
colsupport(::BidiagonalLayout, A, j) =
662+
function colsupport(::BidiagonalLayout, A, j)
663+
isempty(j) && return 1:0
657664
bidiagonaluplo(A) == 'L' ? (minimum(j):min(size(A,1),maximum(j)+1)) : (max(minimum(j)-1,1):maximum(j))
658-
rowsupport(::BidiagonalLayout, A, j) =
665+
end
666+
function rowsupport(::BidiagonalLayout, A, j)
667+
isempty(j) && return 1:0
659668
bidiagonaluplo(A) == 'U' ? (minimum(j):min(size(A,2),maximum(j)+1)) : (max(minimum(j)-1,1):maximum(j))
669+
end
660670

661671
colsupport(::AbstractTridiagonalLayout, A, j) = max(minimum(j)-1,1):min(size(A,1),maximum(j)+1)
662672
rowsupport(::AbstractTridiagonalLayout, A, j) = max(minimum(j)-1,1):min(size(A,2),maximum(j)+1)

test/test_layoutarray.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,14 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
171171
C = randn(ComplexF64,5,5)
172172
@test ArrayLayouts.lmul!(2, Hermitian(copy(C))) == ArrayLayouts.rmul!(Hermitian(copy(C)), 2) == 2Hermitian(C)
173173

174-
if VERSION v"1.5"
175-
@test ldiv!(2, deepcopy(b)) == rdiv!(deepcopy(b), 2) == 2\b
176-
@test ldiv!(2, deepcopy(A)) == rdiv!(deepcopy(A), 2) == 2\A
177-
@test ldiv!(2, deepcopy(A)') == rdiv!(deepcopy(A)', 2) == 2\A'
178-
@test ldiv!(2, transpose(deepcopy(A))) == rdiv!(transpose(deepcopy(A)), 2) == 2\transpose(A)
179-
@test ldiv!(2, Symmetric(deepcopy(A))) == rdiv!(Symmetric(deepcopy(A)), 2) == 2\Symmetric(A)
180-
@test ldiv!(2, Hermitian(deepcopy(A))) == rdiv!(Hermitian(deepcopy(A)), 2) == 2\Hermitian(A)
181-
@test ArrayLayouts.ldiv!(2, Hermitian(copy(C))) == ArrayLayouts.rdiv!(Hermitian(copy(C)), 2) == 2\Hermitian(C)
182-
end
174+
175+
@test ldiv!(2, deepcopy(b)) == rdiv!(deepcopy(b), 2) == 2\b
176+
@test ldiv!(2, deepcopy(A)) == rdiv!(deepcopy(A), 2) == 2\A
177+
@test ldiv!(2, deepcopy(A)') == rdiv!(deepcopy(A)', 2) == 2\A'
178+
@test ldiv!(2, transpose(deepcopy(A))) == rdiv!(transpose(deepcopy(A)), 2) == 2\transpose(A)
179+
@test ldiv!(2, Symmetric(deepcopy(A))) == rdiv!(Symmetric(deepcopy(A)), 2) == 2\Symmetric(A)
180+
@test ldiv!(2, Hermitian(deepcopy(A))) == rdiv!(Hermitian(deepcopy(A)), 2) == 2\Hermitian(A)
181+
@test ArrayLayouts.ldiv!(2, Hermitian(copy(C))) == ArrayLayouts.rdiv!(Hermitian(copy(C)), 2) == 2\Hermitian(C)
183182
end
184183

185184
@testset "pow/I" begin

test/test_layouts.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ struct FooNumber <: Number end
139139
@testset "Triangular of Tridiagonal" begin
140140
U = UpperTriangular(T)
141141
L = LowerTriangular(T)
142-
@test MemoryLayout(U) isa BidiagonalLayout{DenseColumnMajor}
143-
@test MemoryLayout(L) isa BidiagonalLayout{DenseColumnMajor}
142+
@test MemoryLayout(U) isa BidiagonalLayout{DenseColumnMajor,DenseColumnMajor}
143+
@test MemoryLayout(L) isa BidiagonalLayout{DenseColumnMajor,DenseColumnMajor}
144144
@test diagonaldata(U) == diagonaldata(L) == diagonaldata(T)
145145
@test subdiagonaldata(L) == subdiagonaldata(T)
146146
@test supdiagonaldata(U) == supdiagonaldata(T)
@@ -304,6 +304,12 @@ struct FooNumber <: Number end
304304
v = SubArray(Fill(1,10),(1:3,))
305305
@test ArrayLayouts.sub_materialize(v) Fill(1,3)
306306
@test ArrayLayouts._copyto!(Vector{Float64}(undef,3), v) == ones(3)
307+
308+
T = Tridiagonal(Fill(1,10), Fill(2,11), Fill(3,10))
309+
@test MemoryLayout(UpperTriangular(T)) isa BidiagonalLayout{FillLayout,FillLayout}
310+
@test MemoryLayout(LowerTriangular(T)) isa BidiagonalLayout{FillLayout,FillLayout}
311+
@test MemoryLayout(UnitUpperTriangular(T)) isa BidiagonalLayout{FillLayout,FillLayout}
312+
@test MemoryLayout(UnitLowerTriangular(T)) isa BidiagonalLayout{FillLayout,FillLayout}
307313
end
308314

309315
@testset "Triangular col/rowsupport" begin

test/test_muladd.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ArrayLayouts, FillArrays, Random, Test
1+
using ArrayLayouts, FillArrays, Random, LinearAlgebra, Test
22
import ArrayLayouts: DenseColumnMajor, AbstractStridedLayout, AbstractColumnMajor, DiagonalLayout, mul, Mul, zero!
33

44
Random.seed!(0)
@@ -669,4 +669,21 @@ Random.seed!(0)
669669
@test copy(MulAdd(A,B̃)) == A*
670670
@test eltype(MulAdd(A,B̃)) == eltype(B̃)
671671
end
672+
673+
@testset "Bidiagonal" begin
674+
BidiagU = Bidiagonal(randn(5), randn(4), :U)
675+
BidiagL = Bidiagonal(randn(5), randn(4), :L)
676+
Tridiag = Tridiagonal(randn(4), randn(5), randn(4))
677+
SymTri = SymTridiagonal(randn(5), randn(4))
678+
Diag = Diagonal(randn(5))
679+
@test typeof(mul(BidiagU,Diag)) <: Bidiagonal
680+
@test typeof(mul(BidiagL,Diag)) <: Bidiagonal
681+
@test typeof(mul(Tridiag,Diag)) <: Tridiagonal
682+
@test typeof(mul(SymTri,Diag)) <: Tridiagonal
683+
684+
@test typeof(mul(BidiagU,Diag)) <: Bidiagonal
685+
@test typeof(mul(Diag,BidiagL)) <: Bidiagonal
686+
@test typeof(mul(Diag,Tridiag)) <: Tridiagonal
687+
@test typeof(mul(Diag,SymTri)) <: Tridiagonal
688+
end
672689
end

0 commit comments

Comments
 (0)