Skip to content

Commit 5e0f275

Browse files
authored
IndexCartesian implementation of generic muladd (#24)
* Cartesian blasmul!, symmetric col/rowsupport, copyto! in Ldiv * v0.3.5 * Triangular getindex * Add Bidiagonal support * Add tests
1 parent d8b862b commit 5e0f275

File tree

8 files changed

+628
-503
lines changed

8 files changed

+628
-503
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.3.4"
4+
version = "0.3.5"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,8 @@ include("factorizations.jl")
104104

105105
@inline layout_getindex(A, I...) = sub_materialize(view(A, I...))
106106

107-
108-
109-
macro layoutmatrix(Typ)
107+
macro _layoutgetindex(Typ)
110108
esc(quote
111-
ArrayLayouts.@layoutldiv $Typ
112-
ArrayLayouts.@layoutmul $Typ
113-
ArrayLayouts.@layoutlmul $Typ
114-
ArrayLayouts.@layoutfactorizations $Typ
115-
116109
@inline Base.getindex(A::$Typ, kr::Colon, jr::Colon) = ArrayLayouts.layout_getindex(A, kr, jr)
117110
@inline Base.getindex(A::$Typ, kr::Colon, jr::AbstractUnitRange) = ArrayLayouts.layout_getindex(A, kr, jr)
118111
@inline Base.getindex(A::$Typ, kr::AbstractUnitRange, jr::Colon) = ArrayLayouts.layout_getindex(A, kr, jr)
@@ -123,6 +116,24 @@ macro layoutmatrix(Typ)
123116
end)
124117
end
125118

119+
macro layoutgetindex(Typ)
120+
esc(quote
121+
ArrayLayouts.@_layoutgetindex $Typ
122+
ArrayLayouts.@_layoutgetindex LinearAlgebra.AbstractTriangular{<:Any,<:$Typ}
123+
end)
124+
end
125+
126+
127+
macro layoutmatrix(Typ)
128+
esc(quote
129+
ArrayLayouts.@layoutldiv $Typ
130+
ArrayLayouts.@layoutmul $Typ
131+
ArrayLayouts.@layoutlmul $Typ
132+
ArrayLayouts.@layoutfactorizations $Typ
133+
ArrayLayouts.@layoutgetindex $Typ
134+
end)
135+
end
136+
126137
@layoutmatrix LayoutMatrix
127138

128139
getindex(A::LayoutVector, kr::AbstractVector) = layout_getindex(A, kr)

src/diagonal.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11

2-
rowsupport(::DiagonalLayout, _, k) = isempty(k) ? (1:0) : minimum(k):maximum(k)
3-
colsupport(::DiagonalLayout, _, j) = isempty(j) ? (1:0) : minimum(j):maximum(j)
4-
52
###
63
# Lmul
74
####

src/memorylayout.jl

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -500,32 +500,47 @@ abstract type AbstractBandedLayout <: MemoryLayout end
500500
abstract type AbstractTridiagonalLayout <: AbstractBandedLayout end
501501

502502
struct DiagonalLayout{ML} <: AbstractBandedLayout end
503+
struct BidiagonalLayout{ML} <: AbstractBandedLayout end
503504
struct SymTridiagonalLayout{ML} <: AbstractTridiagonalLayout end
504505
struct TridiagonalLayout{ML} <: AbstractTridiagonalLayout end
505506

506507
diagonallayout(_) = DiagonalLayout{UnknownLayout}()
507508
diagonallayout(::ML) where ML<:AbstractStridedLayout = DiagonalLayout{ML}()
508509
MemoryLayout(D::Type{Diagonal{T,P}}) where {T,P} = diagonallayout(MemoryLayout(P))
509-
diagonaldata(D::Diagonal) = parent(D)
510-
510+
MemoryLayout(::Type{Bidiagonal{T,V}}) where {T,V} = BidiagonalLayout{typeof(MemoryLayout(V))}()
511511
MemoryLayout(::Type{SymTridiagonal{T,P}}) where {T,P} = SymTridiagonalLayout{typeof(MemoryLayout(P))}()
512+
MemoryLayout(::Type{Tridiagonal{T,P}}) where {T,P} = TridiagonalLayout{typeof(MemoryLayout(P))}()
513+
514+
bidiagonaluplo(A::Bidiagonal) = A.uplo
515+
bidiagonaluplo(A::AdjOrTrans) = bidiagonaluplo(parent(A)) == 'L' ? 'U' : 'L'
516+
517+
diagonaldata(D::Diagonal) = parent(D)
518+
diagonaldata(D::Bidiagonal) = D.dv
512519
diagonaldata(D::SymTridiagonal) = D.dv
520+
diagonaldata(D::Tridiagonal) = D.d
521+
522+
supdiagonaldata(D::Bidiagonal) = D.uplo == 'U' ? D.ev : throw(ArgumentError("$D is lower-bidiagonal"))
523+
subdiagonaldata(D::Bidiagonal) = D.uplo == 'L' ? D.ev : throw(ArgumentError("$D is upper-bidiagonal"))
524+
513525
supdiagonaldata(D::SymTridiagonal) = D.ev
514526
subdiagonaldata(D::SymTridiagonal) = D.ev
515527

516-
MemoryLayout(::Type{Tridiagonal{T,P}}) where {T,P} = TridiagonalLayout{typeof(MemoryLayout(P))}()
517-
diagonaldata(D::Tridiagonal) = D.d
518528
subdiagonaldata(D::Tridiagonal) = D.dl
519529
supdiagonaldata(D::Tridiagonal) = D.du
520530

521-
522531
transposelayout(ml::DiagonalLayout) = ml
532+
transposelayout(ml::BidiagonalLayout) = ml
523533
transposelayout(ml::SymTridiagonalLayout) = ml
524534
transposelayout(ml::TridiagonalLayout) = ml
525535
transposelayout(ml::ConjLayout{DiagonalLayout}) = ml
526536

527537
adjointlayout(::Type{<:Real}, ml::SymTridiagonalLayout) = ml
528538
adjointlayout(::Type{<:Real}, ml::TridiagonalLayout) = ml
539+
adjointlayout(::Type{<:Real}, ml::BidiagonalLayout) = ml
540+
541+
symmetriclayout(B::BidiagonalLayout{ML}) where ML = SymTridiagonalLayout{ML}()
542+
hermitianlayout(::Type{<:Real}, B::BidiagonalLayout{ML}) where ML = SymTridiagonalLayout{ML}()
543+
hermitianlayout(_, B::BidiagonalLayout) = HermitianLayout{typeof(B)}()
529544

530545
subdiagonaldata(D::Transpose) = supdiagonaldata(parent(D))
531546
supdiagonaldata(D::Transpose) = subdiagonaldata(parent(D))
@@ -535,6 +550,10 @@ subdiagonaldata(D::Adjoint{<:Real}) = supdiagonaldata(parent(D))
535550
supdiagonaldata(D::Adjoint{<:Real}) = subdiagonaldata(parent(D))
536551
diagonaldata(D::Adjoint{<:Real}) = diagonaldata(parent(D))
537552

553+
diagonaldata(S::HermOrSym) = diagonaldata(parent(S))
554+
subdiagonaldata(S::HermOrSym) = symmetricuplo(S) == 'L' ? subdiagonaldata(parent(S)) : supdiagonaldata(parent(S))
555+
supdiagonaldata(S::HermOrSym) = symmetricuplo(S) == 'L' ? subdiagonaldata(parent(S)) : supdiagonaldata(parent(S))
556+
538557
###
539558
# Fill
540559
####
@@ -575,3 +594,20 @@ colsupport(A) = colsupport(A, axes(A,2))
575594

576595
rowsupport(::ZerosLayout, A, _) = 1:0
577596
colsupport(::ZerosLayout, A, _) = 1:0
597+
598+
rowsupport(::DiagonalLayout, _, k) = isempty(k) ? (1:0) : minimum(k):maximum(k)
599+
colsupport(::DiagonalLayout, _, j) = isempty(j) ? (1:0) : minimum(j):maximum(j)
600+
601+
colsupport(::BidiagonalLayout, A, j) =
602+
bidiagonaluplo(A) == 'L' ? (minimum(j):min(size(A,1),maximum(j)+1)) : (max(minimum(j)-1,1):maximum(j))
603+
rowsupport(::BidiagonalLayout, A, j) =
604+
bidiagonaluplo(A) == 'U' ? (minimum(j):min(size(A,2),maximum(j)+1)) : (max(minimum(j)-1,1):maximum(j))
605+
606+
colsupport(::AbstractTridiagonalLayout, A, j) = max(minimum(j)-1,1):min(size(A,1),maximum(j)+1)
607+
rowsupport(::AbstractTridiagonalLayout, A, j) = max(minimum(j)-1,1):min(size(A,2),maximum(j)+1)
608+
609+
colsupport(::SymmetricLayout, A, j) = first(colsupport(symmetricdata(A),j)):last(rowsupport(symmetricdata(A),j))
610+
rowsupport(::SymmetricLayout, A, j) = colsupport(A, j)
611+
612+
colsupport(::HermitianLayout, A, j) = first(colsupport(hermitiandata(A),j)):last(rowsupport(hermitiandata(A),j))
613+
rowsupport(::HermitianLayout, A, j) = colsupport(A, j)

src/muladd.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractMatrix, β, C::Abstr
194194
C
195195
end
196196

197-
function default_blasmul!(α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector)
197+
function _default_blasmul!(::IndexLinear, α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector)
198198
mA, nA = size(A)
199199
mB = length(B)
200200
nA == mB || throw(DimensionMismatch("Dimensions must match"))
@@ -217,6 +217,30 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractVector, β, C::Abstr
217217
C
218218
end
219219

220+
function _default_blasmul!(::IndexCartesian, α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector)
221+
mA, nA = size(A)
222+
mB = length(B)
223+
nA == mB || throw(DimensionMismatch("Dimensions must match"))
224+
length(C) == mA || throw(DimensionMismatch("Dimensions must match"))
225+
226+
lmul!(β, C)
227+
(nA == 0 || mB == 0) && return C
228+
229+
z = zero(A[1,1]*B[1] + A[1,1]*B[1])
230+
231+
@inbounds for k in colsupport(B,1)
232+
b = B[k]
233+
for i = colsupport(A,k)
234+
C[i] += α * A[i,k] * b
235+
end
236+
end
237+
238+
C
239+
end
240+
241+
default_blasmul!(α, A::AbstractMatrix, B::AbstractVector, β, C::AbstractVector) =
242+
_default_blasmul!(Base.IndexStyle(typeof(A)), α, A, B, β, C)
243+
220244
function materialize!(M::MatMulMatAdd)
221245
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
222246
if C B

src/triangular.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ materialize!(M::MatRmulMat{<:AbstractStridedLayout,<:TriangularLayout}) = Linear
118118

119119
@inline function copyto!(dest::AbstractArray, M::Ldiv{<:TriangularLayout})
120120
A, B = M.A, M.B
121-
dest B || (dest .= B)
121+
dest B || copyto!(dest, B)
122122
ldiv!(A, dest)
123123
end
124124

test/test_layouts.jl

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import ArrayLayouts: MemoryLayout, DenseRowMajor, DenseColumnMajor, StridedLayou
66
UnitLowerTriangularLayout, ScalarLayout, UnknownLayout,
77
hermitiandata, symmetricdata, FillLayout, ZerosLayout,
88
DiagonalLayout, TridiagonalLayout, SymTridiagonalLayout, colsupport, rowsupport,
9-
diagonaldata, subdiagonaldata, supdiagonaldata
9+
diagonaldata, subdiagonaldata, supdiagonaldata, BidiagonalLayout, bidiagonaluplo
1010

1111
struct FooBar end
1212
struct FooNumber <: Number end
@@ -80,7 +80,45 @@ struct FooNumber <: Number end
8080
@test MemoryLayout(view(randn(5)',[1,3])) == UnknownLayout()
8181
end
8282

83-
@testset "Symmetric/Hermitian MemoryLayout" begin
83+
@testset "Bi/Tridiagonal" begin
84+
T = Tridiagonal(randn(5),randn(6),randn(5))
85+
S = SymTridiagonal(T.d, T.du)
86+
Bl = Bidiagonal(T.d, T.dl, :L)
87+
Bu = Bidiagonal(T.d, T.du, :U)
88+
89+
@test MemoryLayout(T) isa TridiagonalLayout
90+
@test MemoryLayout(Adjoint(T)) isa TridiagonalLayout
91+
@test MemoryLayout(Transpose(T)) isa TridiagonalLayout
92+
@test MemoryLayout(S) isa SymTridiagonalLayout
93+
@test MemoryLayout(Adjoint(S)) isa SymTridiagonalLayout
94+
@test MemoryLayout(Transpose(S)) isa SymTridiagonalLayout
95+
@test MemoryLayout(Bl) isa BidiagonalLayout
96+
@test MemoryLayout(Adjoint(Bl)) isa BidiagonalLayout
97+
@test MemoryLayout(Transpose(Bl)) isa BidiagonalLayout
98+
@test MemoryLayout(Bu) isa BidiagonalLayout
99+
@test MemoryLayout(Adjoint(Bu)) isa BidiagonalLayout
100+
@test MemoryLayout(Transpose(Bu)) isa BidiagonalLayout
101+
102+
@test bidiagonaluplo(Bl) == bidiagonaluplo(Adjoint(Bu)) == 'L'
103+
@test bidiagonaluplo(Bu) == bidiagonaluplo(Adjoint(Bl)) == 'U'
104+
105+
@test diagonaldata(T) == diagonaldata(T') == diagonaldata(S) == diagonaldata(Bl) == diagonaldata(Bu)
106+
@test supdiagonaldata(T) == subdiagonaldata(Adjoint(T)) == subdiagonaldata(Transpose(T)) ==
107+
supdiagonaldata(S) == subdiagonaldata(S) ==
108+
supdiagonaldata(Bu) == subdiagonaldata(Adjoint(Bu)) == subdiagonaldata(Transpose(Bu))
109+
@test subdiagonaldata(T) == supdiagonaldata(Adjoint(T)) == supdiagonaldata(Transpose(T)) ==
110+
subdiagonaldata(Bl) == supdiagonaldata(Adjoint(Bl)) == supdiagonaldata(Transpose(Bl)) ==
111+
T.dl
112+
113+
@test colsupport(T,3) == rowsupport(T,3) == colsupport(S,3) == rowsupport(S,3) == 2:4
114+
@test colsupport(T,3:6) == rowsupport(T,3:6) == colsupport(S,3:6) == rowsupport(S,3:6) == 2:6
115+
@test colsupport(Bl,3) == rowsupport(Bu,3) == rowsupport(Adjoint(Bl),3) == 3:4
116+
@test rowsupport(Bl,3) == colsupport(Bu,3) == colsupport(Adjoint(Bl),3) == 2:3
117+
@test colsupport(Bl,3:6) == rowsupport(Bu,3:6) == 3:6
118+
@test colsupport(Bu,3:6) == rowsupport(Bl,3:6) == 2:6
119+
end
120+
121+
@testset "Symmetric/Hermitian" begin
84122
A = [1.0 2; 3 4]
85123
@test MemoryLayout(Symmetric(A)) == SymmetricLayout{DenseColumnMajor}()
86124
@test MemoryLayout(Hermitian(A)) == SymmetricLayout{DenseColumnMajor}()
@@ -108,6 +146,11 @@ struct FooNumber <: Number end
108146
@test symmetricdata(Symmetric(transpose(A))) transpose(A)
109147
@test symmetricdata(Hermitian(transpose(A))) transpose(A)
110148

149+
@test colsupport(Symmetric(A),2) colsupport(Symmetric(A),1:2)
150+
rowsupport(Symmetric(A),2) rowsupport(Symmetric(A),1:2) 1:2
151+
@test colsupport(Hermitian(A),2) colsupport(Hermitian(A),1:2)
152+
rowsupport(Hermitian(A),2) rowsupport(Hermitian(A),1:2) 1:2
153+
111154
B = [1.0+im 2; 3 4]
112155
@test MemoryLayout(Symmetric(B)) == SymmetricLayout{DenseColumnMajor}()
113156
@test MemoryLayout(Hermitian(B)) == HermitianLayout{DenseColumnMajor}()
@@ -132,6 +175,26 @@ struct FooNumber <: Number end
132175
@test hermitiandata(Hermitian(B')) B'
133176
@test symmetricdata(Symmetric(transpose(B))) transpose(B)
134177
@test hermitiandata(Hermitian(transpose(B))) transpose(B)
178+
179+
@testset "Bidiagonal" begin
180+
B = Bidiagonal(randn(6),randn(5),:U)
181+
Bc = Bidiagonal(randn(6) .+ 0im,randn(5) .+ 1im,:U)
182+
S = Symmetric(B)
183+
H = Hermitian(B)
184+
Sc = Symmetric(Bc)
185+
Hc = Hermitian(Bc)
186+
187+
@test MemoryLayout(S) isa SymTridiagonalLayout
188+
@test MemoryLayout(H) isa SymTridiagonalLayout
189+
@test MemoryLayout(Sc) isa SymTridiagonalLayout
190+
@test MemoryLayout(Hc) isa HermitianLayout
191+
192+
@test diagonaldata(S) == diagonaldata(B)
193+
@test subdiagonaldata(S) == supdiagonaldata(S) == supdiagonaldata(B)
194+
195+
@test colsupport(S,3) == colsupport(H,3) == colsupport(Sc,3) == colsupport(Hc,3) == 2:4
196+
@test rowsupport(S,3) == rowsupport(H,3) == rowsupport(Sc,3) == rowsupport(Hc,3) == 2:4
197+
end
135198
end
136199

137200
@testset "triangular MemoryLayout" begin
@@ -233,19 +296,4 @@ struct FooNumber <: Number end
233296
MemoryLayout(revD)
234297
@test 0 == @allocated MemoryLayout(revD)
235298
end
236-
237-
@testset "Tridiagonal" begin
238-
T = Tridiagonal(randn(5),randn(6),randn(5))
239-
S = SymTridiagonal(T.d, T.du)
240-
@test MemoryLayout(T) isa TridiagonalLayout
241-
@test MemoryLayout(Adjoint(T)) isa TridiagonalLayout
242-
@test MemoryLayout(Transpose(T)) isa TridiagonalLayout
243-
@test MemoryLayout(S) isa SymTridiagonalLayout
244-
@test MemoryLayout(Adjoint(S)) isa SymTridiagonalLayout
245-
@test MemoryLayout(Transpose(S)) isa SymTridiagonalLayout
246-
247-
@test diagonaldata(T) == diagonaldata(T') == diagonaldata(S)
248-
@test supdiagonaldata(T) == subdiagonaldata(Adjoint(T)) == subdiagonaldata(Transpose(T)) == supdiagonaldata(S) == subdiagonaldata(S)
249-
@test subdiagonaldata(T) == supdiagonaldata(Adjoint(T)) == supdiagonaldata(Transpose(T)) == T.dl
250-
end
251299
end

0 commit comments

Comments
 (0)