Skip to content

Commit cacaffc

Browse files
committed
Make inv and solves work for RFP
1 parent 6675052 commit cacaffc

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

src/lapack.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module LAPACK2
22

33
using Base.LinAlg: BlasInt, chkstride1, LAPACKException
44
using Base.LinAlg.BLAS: @blasfunc
5-
using Base.LinAlg.LAPACK: chkdiag, chkuplo
5+
using Base.LinAlg.LAPACK: chkdiag, chkside, chkuplo
66

77
# LAPACK wrappers
88
import Base.BLAS.@blasfunc
@@ -604,12 +604,19 @@ for (f, elty) in ((:dtfsm_, :Float64),
604604
(:ctfsm_, :Complex64))
605605

606606
@eval begin
607-
function pftrs!(transr::Char, side::Char, uplo::Char, trans::Char, diag::Char, alpha::Real, A::StridedVector{$elty}, B::StridedMatrix{$elty})
607+
function tfsm!(transr::Char,
608+
side::Char,
609+
uplo::Char,
610+
trans::Char,
611+
diag::Char,
612+
alpha::$elty,
613+
A::StridedVector{$elty},
614+
B::StridedVecOrMat{$elty})
608615
chkuplo(uplo)
609616
chkside(side)
610617
chkdiag(diag)
611618
chkstride1(B)
612-
m, n = size(B)
619+
m, n = size(B, 1), size(B, 2)
613620
if round(Int, div(sqrt(8length(A)), 2)) != m
614621
throw(DimensionMismatch("First dimension of B must equal $(round(Int, div(sqrt(8length(A)), 2))), got $m"))
615622
end
@@ -623,7 +630,7 @@ for (f, elty) in ((:dtfsm_, :Float64),
623630
&diag, &m, &n, &alpha,
624631
A, B, &ldb)
625632

626-
B
633+
return B
627634
end
628635
end
629636
end

src/rectfullpacked.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ end
128128
Base.LinAlg.inv!(A::TriangularRFP) = TriangularRFP(LAPACK2.tftri!(A.transr, A.uplo, 'N', A.data), A.transr, A.uplo)
129129
Base.LinAlg.inv(A::TriangularRFP) = Base.LinAlg.inv!(copy(A))
130130

131+
A_ldiv_B!(A::TriangularRFP{T}, B::StridedVecOrMat{T}) where T =
132+
LAPACK2.tfsm!(A.transr, 'L', A.uplo, 'N', 'N', one(T), A.data, B)
133+
(\)(A::TriangularRFP, B::StridedVecOrMat) = A_ldiv_B!(A, copy(B))
134+
131135
struct CholeskyRFP{T<:BlasFloat} <: Factorization{T}
132136
data::Vector{T}
133137
transr::Char
@@ -141,8 +145,8 @@ Base.LinAlg.factorize(A::HermitianRFP) = cholfact(A)
141145
Base.copy(F::CholeskyRFP{T}) where T = CholeskyRFP{T}(copy(F.data), F.transr, F.uplo)
142146

143147
# Solve
144-
\(A::CholeskyRFP, B::StridedVecOrMat) = LAPACK2.pftrs!(A.transr, A.uplo, A.data, copy(B))
145-
\(A::HermitianRFP, B::StridedVecOrMat) = cholfact(A)\B
148+
(\)(A::CholeskyRFP, B::StridedVecOrMat) = LAPACK2.pftrs!(A.transr, A.uplo, A.data, copy(B))
149+
(\)(A::HermitianRFP, B::StridedVecOrMat) = cholfact(A)\B
146150

147151
Base.LinAlg.inv!(A::CholeskyRFP) = HermitianRFP(LAPACK2.pftri!(A.transr, A.uplo, A.data), A.transr, A.uplo)
148152
Base.LinAlg.inv(A::CholeskyRFP) = Base.LinAlg.inv!(copy(A))

test/rectfullpacked.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@ import LinearAlgebra: Ac_mul_A_RFP, TriangularRFP
6060
n in (6, 7),
6161
uplo in (:L, :U)
6262

63-
A = triu(rand(elty, n, n))
64-
A_RFP = TriangularRFP(A)
63+
A = lufact(rand(elty, n, n))[:U]
64+
A = uplo == :U ? A : A'
65+
A_RFP = TriangularRFP(A, uplo)
6566
o = ones(elty, n)
6667

6768
@test_broken A A_RFP
6869
@test A full(A_RFP)
70+
@test A\o A_RFP\o
6971
@test inv(A) full(inv(A_RFP))
7072
end
7173
end

0 commit comments

Comments
 (0)