Skip to content

Commit 47de258

Browse files
authored
Cholesky (#63)
* Add cholesky to factorizations * Update factorizations.jl * Update factorizations.jl * add tests, cholesky default * Update test_layoutarray.jl
1 parent 219e0af commit 47de258

File tree

3 files changed

+35
-13
lines changed

3 files changed

+35
-13
lines changed

src/ArrayLayouts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcas
3535

3636
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, factorize, qr, lu, cholesky,
3737
norm2, norm1, normInf, normMinusInf, qr, lu, qr!, lu!, AdjOrTrans, HermOrSym, copy_oftype,
38-
AdjointAbsVec, TransposeAbsVec
38+
AdjointAbsVec, TransposeAbsVec, cholcopy
3939

4040
import LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex
4141

src/factorizations.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,13 @@ end
281281
__qr(layout, lengths, A; kwds...) = Base.invoke(qr, Tuple{AbstractMatrix{eltype(A)}}, A; kwds...)
282282
_qr(layout, axes, A; kwds...) = __qr(layout, map(length, axes), A; kwds...)
283283
_qr(layout, axes, A, pivot::P; kwds...) where P = Base.invoke(qr, Tuple{AbstractMatrix{eltype(A)},P}, A, pivot; kwds...)
284+
_qr!(layout, axes, A, args...; kwds...) = error("Overload _qr!(::$(typeof(layout)), axes, A)")
284285
_lu(layout, axes, A; kwds...) = Base.invoke(lu, Tuple{AbstractMatrix{eltype(A)}}, A; kwds...)
285286
_lu(layout, axes, A, pivot::P; kwds...) where P = Base.invoke(lu, Tuple{AbstractMatrix{eltype(A)},P}, A, pivot; kwds...)
286-
_qr!(layout, axes, A, args...; kwds...) = error("Overload _qr!(::$(typeof(layout)), axes, A)")
287287
_lu!(layout, axes, A, args...; kwds...) = error("Overload _lu!(::$(typeof(layout)), axes, A)")
288+
_cholesky(layout, axes, A, ::Val{false}=Val(false); check::Bool = true) = cholesky!(cholcopy(A); check = check)
289+
_cholesky(layout, axes, A, ::Val{true}; tol = 0.0, check::Bool = true) = cholesky!(cholcopy(A), Val(true); tol = tol, check = check)
290+
_cholesky!(layout, axes, A, v::Val{tf}; kwds...) where tf = Base.invoke(cholesky!, Tuple{LinearAlgebra.RealHermSymComplexHerm,Val{tf}}, A, v; kwds...)
288291
_factorize(layout, axes, A) = qr(A) # Default to QR
289292

290293
_inv_eye(_, ::Type{T}, axs::NTuple{2,Base.OneTo{Int}}) where T = Matrix{T}(I, map(length,axs)...)
@@ -307,6 +310,8 @@ end
307310

308311
macro _layoutfactorizations(Typ)
309312
esc(quote
313+
LinearAlgebra.cholesky(A::$Typ, args...; kwds...) = ArrayLayouts._cholesky(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
314+
LinearAlgebra.cholesky!(A::$Typ, v::Val{false}=Val(false); check::Bool = true) = ArrayLayouts._cholesky!(ArrayLayouts.MemoryLayout(A), axes(A), A, v; check=check)
310315
LinearAlgebra.qr(A::$Typ, args...; kwds...) = ArrayLayouts._qr(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
311316
LinearAlgebra.qr!(A::$Typ, args...; kwds...) = ArrayLayouts._qr!(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
312317
LinearAlgebra.lu(A::$Typ, pivot::Union{Val{false}, Val{true}}; kwds...) = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(A), axes(A), A, pivot; kwds...)
@@ -321,5 +326,17 @@ macro layoutfactorizations(Typ)
321326
esc(quote
322327
ArrayLayouts.@_layoutfactorizations $Typ
323328
ArrayLayouts.@_layoutfactorizations SubArray{<:Any,2,<:$Typ}
329+
ArrayLayouts.@_layoutfactorizations LinearAlgebra.RealHermSymComplexHerm{<:Real,<:$Typ}
330+
ArrayLayouts.@_layoutfactorizations LinearAlgebra.RealHermSymComplexHerm{<:Real,<:SubArray{<:Real,2,<:$Typ}}
324331
end)
325-
end
332+
end
333+
334+
function ldiv!(C::Cholesky{<:Any,<:AbstractMatrix}, B::LayoutArray)
335+
if C.uplo == 'L'
336+
return ldiv!(adjoint(LowerTriangular(C.factors)), ldiv!(LowerTriangular(C.factors), B))
337+
else
338+
return ldiv!(UpperTriangular(C.factors), ldiv!(adjoint(UpperTriangular(C.factors)), B))
339+
end
340+
end
341+
342+

test/test_layoutarray.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,21 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
8181

8282
@test copyto!(view(MyMatrix(Array{Float64}(undef,5,5)),:,:), view(A',:,:)) == A'
8383

84-
@test qr(A).factors qr(A.A).factors
85-
@test qr(A,Val(true)).factors qr(A.A,Val(true)).factors
86-
@test lu(A).factors lu(A.A).factors
87-
@test lu(A,Val(true)).factors lu(A.A,Val(true)).factors
88-
@test_throws ErrorException qr!(A)
89-
@test_throws ErrorException lu!(A)
90-
91-
@test qr(A) isa LinearAlgebra.QRCompactWY
92-
@test inv(A) inv(A.A)
93-
84+
@testset "factorizations" begin
85+
@test qr(A).factors qr(A.A).factors
86+
@test qr(A,Val(true)).factors qr(A.A,Val(true)).factors
87+
@test lu(A).factors lu(A.A).factors
88+
@test lu(A,Val(true)).factors lu(A.A,Val(true)).factors
89+
@test_throws ErrorException qr!(A)
90+
@test_throws ErrorException lu!(A)
91+
92+
@test qr(A) isa LinearAlgebra.QRCompactWY
93+
@test inv(A) inv(A.A)
94+
95+
S = Symmetric(MyMatrix(reshape(inv.(1:25),5,5) + 10I))
96+
@test cholesky(S).U cholesky!(deepcopy(S)).U
97+
@test cholesky(S,Val(true)).U cholesky(Matrix(S),Val(true)).U
98+
end
9499
Bin = randn(5,5)
95100
B = MyMatrix(copy(Bin))
96101
muladd!(1.0, A, A, 2.0, B)

0 commit comments

Comments
 (0)