Skip to content

Commit 7094fa8

Browse files
authored
Support lu factorization for strided arrays (#78)
* support lu factorization for strided * add general ldiv! for LU * fix tests, default to Eye in UniformScaling addition * test LU * zeroeltype to support special cases for vector mul * fix cholesky tests * Tests pass on Julia v1.7 * increase coverage
1 parent 10259da commit 7094fa8

File tree

5 files changed

+112
-18
lines changed

5 files changed

+112
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.7.7"
4+
version = "0.7.8"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcas
3636

3737
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, factorize, qr, lu, cholesky,
3838
norm2, norm1, normInf, normMinusInf, qr, lu, qr!, lu!, AdjOrTrans, HermOrSym, copy_oftype,
39-
AdjointAbsVec, TransposeAbsVec, cholcopy
39+
AdjointAbsVec, TransposeAbsVec, cholcopy, checknonsingular, _apply_ipiv_rows!, ipiv2perm, RealHermSymComplexHerm, chkfullrank
4040

4141
import LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex
4242

@@ -53,6 +53,13 @@ export materialize, materialize!, MulAdd, muladd!, Ldiv, Rdiv, Lmul, Rmul, Dot,
5353
colsupport, rowsupport, layout_getindex, QLayout, LayoutArray, LayoutMatrix, LayoutVector,
5454
RangeCumsum
5555

56+
if VERSION < v"1.7-"
57+
const ColumnNorm = Val{true}
58+
const RowMaximum = Val{true}
59+
const NoPivot = Val{false}
60+
end
61+
62+
5663
struct ApplyBroadcastStyle <: BroadcastStyle end
5764
@inline function copyto!(dest::AbstractArray, bc::Broadcasted{ApplyBroadcastStyle})
5865
@assert length(bc.args) == 1

src/factorizations.jl

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,21 @@ factors are stored with layout SLAY and τ stored with layout TLAY
1515
"""
1616
struct QRPackedLayout{SLAY,TLAY} <: AbstractQRLayout end
1717

18+
19+
"""
20+
LULayout{SLAY}()
21+
22+
represents a packed LU factorization whose
23+
factors are stored with layout SLAY
24+
"""
25+
struct LULayout{SLAY} <: AbstractQRLayout end
26+
1827
MemoryLayout(::Type{<:LinearAlgebra.QRCompactWY{<:Any,MAT}}) where MAT =
1928
QRCompactWYLayout{typeof(MemoryLayout(MAT)),DenseColumnMajor}()
2029
MemoryLayout(::Type{<:LinearAlgebra.QR{<:Any,MAT}}) where MAT =
2130
QRPackedLayout{typeof(MemoryLayout(MAT)),DenseColumnMajor}()
31+
MemoryLayout(::Type{<:LinearAlgebra.LU{<:Any,MAT}}) where MAT =
32+
LULayout{typeof(MemoryLayout(MAT))}()
2233

2334
function materialize!(L::Ldiv{<:QRCompactWYLayout,<:Any,<:Any,<:AbstractVector})
2435
A,b = L.A, L.B
@@ -32,6 +43,15 @@ function materialize!(L::Ldiv{<:QRCompactWYLayout,<:Any,<:Any,<:AbstractMatrix})
3243
B
3344
end
3445

46+
materialize!(L::Ldiv{<:LULayout{<:AbstractColumnMajor},<:AbstractColumnMajor,<:LU{T},<:AbstractVecOrMat{T}}) where {T<:BlasFloat} =
47+
LAPACK.getrs!('N', L.A.factors, L.A.ipiv, L.B)
48+
49+
function materialize!(L::Ldiv{<:LULayout})
50+
A,B = L.A,L.B
51+
_apply_ipiv_rows!(A, B)
52+
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), B))
53+
end
54+
3555
# Julia implementation similar to xgelsy
3656
function materialize!(Ldv::Ldiv{<:QRPackedLayout,<:Any,<:Any,<:AbstractMatrix{T}}) where T
3757
A,B = Ldv.A,Ldv.B
@@ -287,16 +307,69 @@ _lu(layout, axes, A, pivot::P; kwds...) where P = Base.invoke(lu, Tuple{Abstract
287307
_lu!(layout, axes, A, args...; kwds...) = error("Overload _lu!(::$(typeof(layout)), axes, A)")
288308
_cholesky(layout, axes, A, ::Val{false}=Val(false); check::Bool = true) = cholesky!(cholcopy(A); check = check)
289309
_cholesky(layout, axes, A, ::Val{true}; tol = 0.0, check::Bool = true) = cholesky!(cholcopy(A), Val(true); tol = tol, check = check)
290-
_cholesky!(layout, axes, A, v::Val{tf}; kwds...) where tf = Base.invoke(cholesky!, Tuple{LinearAlgebra.RealHermSymComplexHerm,Val{tf}}, A, v; kwds...)
291310
_factorize(layout, axes, A) = qr(A) # Default to QR
292311

312+
313+
_factorize(::AbstractStridedLayout, axes, A) = lu(A)
314+
function _lu!(::AbstractColumnMajor, axes, A::AbstractMatrix{T}, pivot::Union{NoPivot, RowMaximum} = RowMaximum();
315+
check::Bool = true) where T<:BlasFloat
316+
if pivot === NoPivot()
317+
return generic_lufact!(A, pivot; check = check)
318+
end
319+
lpt = LAPACK.getrf!(A)
320+
check && checknonsingular(lpt[3])
321+
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
322+
end
323+
324+
# for some reason only defined for StridedMatrix in LinearAlgebra
325+
function getproperty(F::LU{T,<:LayoutMatrix}, d::Symbol) where T
326+
m, n = size(F)
327+
if d === :L
328+
L = tril!(getfield(F, :factors)[1:m, 1:min(m,n)])
329+
for i = 1:min(m,n); L[i,i] = one(T); end
330+
return L
331+
elseif d === :U
332+
return triu!(getfield(F, :factors)[1:min(m,n), 1:n])
333+
elseif d === :p
334+
return ipiv2perm(getfield(F, :ipiv), m)
335+
elseif d === :P
336+
return Matrix{T}(I, m, m)[:,invperm(F.p)]
337+
else
338+
getfield(F, d)
339+
end
340+
end
341+
342+
293343
# Cholesky factorization without pivoting (copied from stdlib/LinearAlgebra).
294-
function _cholesky!(layout, axes, A::LinearAlgebra.RealHermSymComplexHerm, ::Val{false}; check::Bool = true)
295-
C, info = LinearAlgebra._chol!(A.data, A.uplo == 'U' ? UpperTriangular : LowerTriangular)
344+
345+
# _chol!. Internal methods for calling unpivoted Cholesky
346+
## BLAS/LAPACK element types
347+
function _chol!(::SymmetricLayout{<:AbstractColumnMajor}, A::AbstractMatrix{<:BlasFloat}, ::Type{UpperTriangular})
348+
C, info = LAPACK.potrf!('U', A)
349+
return UpperTriangular(C), info
350+
end
351+
function _chol!(::SymmetricLayout{<:AbstractColumnMajor}, A::AbstractMatrix{<:BlasFloat}, ::Type{LowerTriangular})
352+
C, info = LAPACK.potrf!('L', A)
353+
return LowerTriangular(C), info
354+
end
355+
356+
_chol!(_, A, UL) = LinearAlgebra._chol!(A, UL)
357+
358+
function _cholesky!(layout, axes, A::RealHermSymComplexHerm, ::Val{false}; check::Bool = true)
359+
C, info = _chol!(layout, A.data, A.uplo == 'U' ? UpperTriangular : LowerTriangular)
296360
check && LinearAlgebra.checkpositivedefinite(info)
297361
return Cholesky(C.data, A.uplo, info)
298362
end
299363

364+
function _cholesky!(::SymmetricLayout{<:AbstractColumnMajor}, axes, A::AbstractMatrix{<:BlasReal},
365+
::Val{true}; tol = 0.0, check::Bool = true)
366+
AA, piv, rank, info = LAPACK.pstrf!(A.uplo, A.data, tol)
367+
C = CholeskyPivoted{eltype(AA),typeof(AA)}(AA, A.uplo, piv, rank, tol, info)
368+
check && chkfullrank(C)
369+
return C
370+
end
371+
372+
300373
_inv_eye(_, ::Type{T}, axs::NTuple{2,Base.OneTo{Int}}) where T = Matrix{T}(I, map(length,axs)...)
301374
function _inv_eye(A, ::Type{T}, (rows,cols)) where T
302375
dest = zero!(similar(A, T, (cols,rows)))
@@ -318,14 +391,16 @@ end
318391
macro _layoutfactorizations(Typ)
319392
esc(quote
320393
LinearAlgebra.cholesky(A::$Typ, args...; kwds...) = ArrayLayouts._cholesky(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
321-
LinearAlgebra.cholesky!(A::$Typ, v::Val{false}=Val(false); check::Bool = true) = ArrayLayouts._cholesky!(ArrayLayouts.MemoryLayout(A), axes(A), A, v; check=check)
394+
LinearAlgebra.cholesky!(A::RealHermSymComplexHerm{<:Real,<:$Typ}, v::Val{false}=Val(false); check::Bool = true) = ArrayLayouts._cholesky!(ArrayLayouts.MemoryLayout(A), axes(A), A, v; check=check)
395+
LinearAlgebra.cholesky!(A::RealHermSymComplexHerm{<:Real,<:$Typ}, v::Val{true}; check::Bool = true, tol = 0.0) = ArrayLayouts._cholesky!(ArrayLayouts.MemoryLayout(A), axes(A), A, v; check=check, tol=tol)
322396
LinearAlgebra.qr(A::$Typ, args...; kwds...) = ArrayLayouts._qr(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
323397
LinearAlgebra.qr!(A::$Typ, args...; kwds...) = ArrayLayouts._qr!(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
324-
LinearAlgebra.lu(A::$Typ, pivot::Union{Val{false}, Val{true}}; kwds...) = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(A), axes(A), A, pivot; kwds...)
398+
LinearAlgebra.lu(A::$Typ, pivot::Union{NoPivot,RowMaximum}; kwds...) = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(A), axes(A), A, pivot; kwds...)
325399
LinearAlgebra.lu(A::$Typ{T}; kwds...) where T = ArrayLayouts._lu(ArrayLayouts.MemoryLayout(A), axes(A), A; kwds...)
326400
LinearAlgebra.lu!(A::$Typ, args...; kwds...) = ArrayLayouts._lu!(ArrayLayouts.MemoryLayout(A), axes(A), A, args...; kwds...)
327401
LinearAlgebra.factorize(A::$Typ) = ArrayLayouts._factorize(ArrayLayouts.MemoryLayout(A), axes(A), A)
328402
LinearAlgebra.inv(A::$Typ) = ArrayLayouts._inv(ArrayLayouts.MemoryLayout(A), axes(A), A)
403+
LinearAlgebra.ldiv!(L::LU{<:Any,<:$Typ}, B) = ArrayLayouts.ldiv!(L, B)
329404
end)
330405
end
331406

src/mul.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,13 @@ axes(M::Mul) = _mul_axes(axes(M.A), axes(M.B))
3030

3131
# The following design is to support QuasiArrays.jl where indices
3232
# may not be `Int`
33+
34+
zeroeltype(M) = zero(eltype(M)) # allow special casing where we know more about zero
35+
zeroeltype(M::Mul{<:Any,<:Any,<:SubArray}) = zeroeltype(Mul(parent(M.A), M.B))
36+
3337
function _getindex(::Type{Tuple{AA}}, M::Mul, (k,)::Tuple{AA}) where AA
3438
A,B = M.A, M.B
35-
ret = zero(eltype(M))
39+
ret = zeroeltype(M)
3640
for j = rowsupport(A, k) ∩ colsupport(B,1)
3741
ret += A[k,j] * B[j]
3842
end
@@ -41,7 +45,7 @@ end
4145

4246
function _getindex(::Type{Tuple{AA,BB}}, M::Mul, (k, j)::Tuple{AA,BB}) where {AA,BB}
4347
A,B = M.A,M.B
44-
ret = zero(eltype(M))
48+
ret = zeroeltype(M)
4549
@inbounds for ℓ in (rowsupport(A,k) ∩ colsupport(B,j))
4650
ret += A[k,ℓ] * B[ℓ,j]
4751
end
@@ -292,8 +296,9 @@ LinearAlgebra.dot(x::AbstractVector, A::Symmetric{<:Real,<:LayoutMatrix}, y::Abs
292296

293297
# allow overloading for infinite or lazy case
294298
@inline _power_by_squaring(_, _, A, p) = Base.invoke(Base.power_by_squaring, Tuple{AbstractMatrix,Integer}, A, p)
295-
@inline _apply(_, _, op, A::AbstractMatrix, Λ::UniformScaling) = Base.invoke(op, Tuple{AbstractMatrix,UniformScaling}, A, Λ)
296-
@inline _apply(_, _, op, Λ::UniformScaling, A::AbstractMatrix) = Base.invoke(op, Tuple{UniformScaling,AbstractMatrix}, Λ, A)
299+
# TODO: Remove unnecessary _apply
300+
_apply(_, _, op, Λ::UniformScaling, A::AbstractMatrix) = op(Diagonal(Fill(Λ.λ,size(A,1))), A)
301+
_apply(_, _, op, A::AbstractMatrix, Λ::UniformScaling) = op(A, Diagonal(Fill(Λ.λ,size(A,1))))
297302

298303
for Typ in (:LayoutMatrix, :(Symmetric{<:Any,<:LayoutMatrix}), :(Hermitian{<:Any,<:LayoutMatrix}),
299304
:(Adjoint{<:Any,<:LayoutMatrix}), :(Transpose{<:Any,<:LayoutMatrix}))

test/test_layoutarray.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
using ArrayLayouts, LinearAlgebra, FillArrays, Base64, Test
2-
import ArrayLayouts: sub_materialize
2+
import ArrayLayouts: sub_materialize, MemoryLayout, ColumnNorm, RowMaximum
33

4-
if VERSION < v"1.7-"
5-
ColumnNorm() = Val(true)
6-
RowMaximum() = Val(true)
7-
end
84

95
struct MyMatrix <: LayoutMatrix{Float64}
106
A::Matrix{Float64}
@@ -16,6 +12,7 @@ Base.size(A::MyMatrix) = size(A.A)
1612
Base.strides(A::MyMatrix) = strides(A.A)
1713
Base.unsafe_convert(::Type{Ptr{T}}, A::MyMatrix) where T = Base.unsafe_convert(Ptr{T}, A.A)
1814
MemoryLayout(::Type{MyMatrix}) = DenseColumnMajor()
15+
Base.copy(A::MyMatrix) = MyMatrix(copy(A.A))
1916

2017
struct MyVector <: LayoutVector{Float64}
2118
A::Vector{Float64}
@@ -93,14 +90,24 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
9390
@test lu(A).factors ≈ lu(A.A).factors
9491
@test lu(A,RowMaximum()).factors ≈ lu(A.A,RowMaximum()).factors
9592
@test_throws ErrorException qr!(A)
96-
@test_throws ErrorException lu!(A)
93+
@test lu!(copy(A)).factors ≈ lu(A.A).factors
94+
b = randn(5)
95+
@test all(A \ b .≡ A.A \ b)
96+
@test all(lu(A).L .≡ lu(A.A).L)
97+
@test all(lu(A).U .≡ lu(A.A).U)
98+
@test lu(A).p == lu(A.A).p
99+
@test lu(A).P == lu(A.A).P
97100

98101
@test qr(A) isa LinearAlgebra.QRCompactWY
99102
@test inv(A) ≈ inv(A.A)
100103

101104
S = Symmetric(MyMatrix(reshape(inv.(1:25),5,5) + 10I))
102105
@test cholesky(S).U ≈ @inferred(cholesky!(deepcopy(S))).U
103106
@test cholesky(S,Val(true)).U ≈ cholesky(Matrix(S),Val(true)).U
107+
108+
S = Symmetric(MyMatrix(reshape(inv.(1:25),5,5) + 10I),:L)
109+
@test cholesky(S).U ≈ @inferred(cholesky!(deepcopy(S))).U
110+
@test cholesky(S,Val(true)).U ≈ cholesky(Matrix(S),Val(true)).U
104111
end
105112
Bin = randn(5,5)
106113
B = MyMatrix(copy(Bin))
@@ -208,7 +215,7 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
208215
C = randn(ComplexF64,5,5)
209216
@test ArrayLayouts.lmul!(2, Hermitian(copy(C))) == ArrayLayouts.rmul!(Hermitian(copy(C)), 2) == 2Hermitian(C)
210217

211-
218+
212219
@test ldiv!(2, deepcopy(b)) == rdiv!(deepcopy(b), 2) == 2\b
213220
@test ldiv!(2, deepcopy(A)) == rdiv!(deepcopy(A), 2) == 2\A
214221
@test ldiv!(2, deepcopy(A)') == rdiv!(deepcopy(A)', 2) == 2\A'

0 commit comments

Comments
 (0)