diff --git a/src/lobpcg.jl b/src/lobpcg.jl index 57eefec2..37670a3c 100644 --- a/src/lobpcg.jl +++ b/src/lobpcg.jl @@ -337,9 +337,14 @@ function (g::BlockGram)(gram, n1::Int, n2::Int, n3::Int, normalized::Bool=true) return end +# Any orthogonalization algorithm should guarantee that X' B X is the identity +# A * X and B * X should be updated accordingly abstract type AbstractOrtho end -struct CholQR{TA} <: AbstractOrtho - gramVBV::TA # to be used in view + +struct CholQROrtho{TA, TB, TM} <: AbstractOrtho + A::TA + B::TB + gramVBV::TM # to be used in view end function rdiv!(A, B::UpperTriangular) @@ -362,7 +367,11 @@ function realdiag!(M::AbstractMatrix{TC}) where TC <: Complex return M end -function (ortho!::CholQR)(XBlocks::Blocks{Generalized}, sizeX = -1; update_AX=false, update_BX=false) where Generalized +# Orthogonalizes X s.t. X' B X = I +# New orthogonal X must span at least the same subspace as the old X +# If update_AX is true, A * X is updated +# If update_BX is true, B * X is updated +function (ortho!::CholQROrtho)(XBlocks::Blocks{Generalized}, sizeX = -1; update_AX=false, update_BX=false) where Generalized useview = sizeX != -1 if sizeX == -1 sizeX = size(XBlocks.block, 2) @@ -370,23 +379,85 @@ function (ortho!::CholQR)(XBlocks::Blocks{Generalized}, sizeX = -1; update_AX=fa X = XBlocks.block BX = XBlocks.B_block # Assumes it is premultiplied AX = XBlocks.A_block + A = ortho!.A + B = ortho!.B + T = real(eltype(X)) # NOTE(review): T is not used in the visible part of this hunk -- confirm it is needed @views gram_view = ortho!.gramVBV[1:sizeX, 1:sizeX] + + # Compute the gram matrix @views if useview mul!(gram_view, adjoint(X[:, 1:sizeX]), BX[:, 1:sizeX]) else mul!(gram_view, adjoint(X), BX) end + # Ensure the diagonal is real; it may not be exactly real due to computational error realdiag!(gram_view) - cholf = cholesky!(Hermitian(gram_view)) - R = cholf.factors - @views if useview - rdiv!(X[:, 1:sizeX], UpperTriangular(R)) - update_AX && rdiv!(AX[:, 1:sizeX], UpperTriangular(R)) - Generalized && update_BX && rdiv!(BX[:, 1:sizeX], UpperTriangular(R)) + + # 
Compute the Cholesky factors R' * R = X' * B * X + # inv(R') * X' * B * X * inv(R) = I + cholf = cholesky!(Hermitian(gram_view), check=false) + if issuccess(cholf) + # Update X to X * inv(R) + # Optionally update A * X to A * X * inv(R) + # Optionally update B * X to B * X * inv(R) + R = cholf.factors + @views if useview + rdiv!(X[:, 1:sizeX], UpperTriangular(R)) + update_AX && rdiv!(AX[:, 1:sizeX], UpperTriangular(R)) + Generalized && update_BX && rdiv!(BX[:, 1:sizeX], UpperTriangular(R)) + else + rdiv!(X, UpperTriangular(R)) + update_AX && rdiv!(AX, UpperTriangular(R)) + Generalized && update_BX && rdiv!(BX, UpperTriangular(R)) + end else - rdiv!(X, UpperTriangular(R)) - update_AX && rdiv!(AX, UpperTriangular(R)) - Generalized && update_BX && rdiv!(BX, UpperTriangular(R)) + # Cholesky failed: X is nearly not full column rank + # Find the QR decomposition of X (NOTE(review): qr!(X) factors all of X in place, and X is zeroed below before qrf.Q is applied -- confirm qrf.Q does not alias X's storage) + # New X is the B-orthonormalized compact Q from the QR decomposition + qrf = qr!(X) + @views if useview + # Extract compact Q into X + X[:, 1:sizeX] .= 0 + I!(X, sizeX) + lmul!(qrf.Q, X[:, 1:sizeX]) + if Generalized + # Find X' B X + mul!(BX[:, 1:sizeX], B, X[:, 1:sizeX]) + mul!(gram_view, adjoint(X[:, 1:sizeX]), BX[:, 1:sizeX]) + realdiag!(gram_view) + # Find R' * R = X' * B * X + cholf = cholesky!(Hermitian(gram_view), check=true) + R = cholf.factors + # Update X to X * inv(R) + # New X is full column rank and B-orthonormal + rdiv!(X[:, 1:sizeX], UpperTriangular(R)) + # Optionally update B * X + update_BX && rdiv!(BX[:, 1:sizeX], UpperTriangular(R)) + end + # Optionally update A * X + update_AX && mul!(AX[:,1:sizeX], A, X[:,1:sizeX]) + else + # Extract compact Q into X + X .= 0 + I!(X, size(X, 2)) + lmul!(qrf.Q, X) + if Generalized + # Find X' B X + mul!(BX, B, X) + mul!(gram_view, adjoint(X), BX) + realdiag!(gram_view) + # Find R' * R = X' * B * X + cholf = cholesky!(Hermitian(gram_view), check=true) + R = cholf.factors + # Update X to X * inv(R) + # New X is full column rank and B-orthonormal + rdiv!(X, 
UpperTriangular(R)) + # Optionally update B * X + update_BX && rdiv!(BX, UpperTriangular(R)) + end + # Optionally update A * X + update_AX && mul!(AX, A, X) + end end return @@ -478,7 +549,7 @@ function LOBPCGIterator(A, B, largest::Bool, X, precond!::RPreconditioner, const iteration = Ref(1) currentBlockSize = Ref(nev) generalized = !(B isa Nothing) - ortho! = CholQR(zeros(T, nev, nev)) + ortho! = CholQROrtho(A, B, zeros(T, nev, nev)) gramABlock = BlockGram(XBlocks) gramBBlock = BlockGram(XBlocks)