Skip to content

Commit 81f30f3

Browse files
author
mohamed82008
committed
Fix constraint support
1 parent f05d8b4 commit 81f30f3

File tree

1 file changed

+30
-29
lines changed

1 file changed

+30
-29
lines changed

src/lobpcg.jl

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -112,22 +112,22 @@ function B_mul_X!(b::Blocks{false}, B, n = 0)
112112
return
113113
end
114114

115-
struct Constraint{T, TA<:AbstractArray{T}, TC}
116-
Y::TA
117-
BY::TA
115+
struct Constraint{T, TVorM<:Union{AbstractVector{T}, AbstractMatrix{T}}, TM<:AbstractMatrix{T}, TC}
116+
Y::TVorM
117+
BY::TVorM
118118
gram_chol::TC
119-
gramYBV::TA # to be used in view
120-
tmp::TA # to be used in view
119+
gramYBV::TM # to be used in view
120+
tmp::TM # to be used in view
121121
end
122122
function Constraint(::Void, B, X)
123-
return Constraint{Void, Matrix{Void}, Void}(Matrix{Void}(0,0), Matrix{Void}(0,0), nothing, Matrix{Void}(0,0), Matrix{Void}(0,0))
123+
return Constraint{Void, Matrix{Void}, Matrix{Void}, Void}(Matrix{Void}(0,0), Matrix{Void}(0,0), nothing, Matrix{Void}(0,0), Matrix{Void}(0,0))
124124
end
125125
function Constraint(Y, B, X)
126126
T = eltype(X)
127127
if B isa Void
128-
B = Y
128+
BY = Y
129129
else
130-
BY = similar(B)
130+
BY = similar(Y)
131131
A_mul_B!(BY, B, Y)
132132
end
133133
gramYBY = Ac_mul_B(Y, BY)
@@ -136,21 +136,22 @@ function Constraint(Y, B, X)
136136
gramYBV = zeros(T, size(Y, 2), size(X, 2))
137137
tmp = similar(gramYBV)
138138

139-
return Constraint(Y, BY, gramYBY_chol, gramYBV, tmp)
139+
return Constraint{eltype(Y), typeof(Y), typeof(gramYBV), typeof(gramYBY_chol)}(Y, BY, gramYBY_chol, gramYBV, tmp)
140140
end
141141

142-
function (constr!::Constraint{Void})(X)
142+
function (constr!::Constraint{Void})(X, X_temp)
143143
nothing
144144
end
145145

146-
function (constr!::Constraint)(X)
146+
function (constr!::Constraint)(X, X_temp)
147147
sizeX = size(X, 2)
148148
sizeY = size(constr!.Y, 2)
149149
gramYBV_view = view(constr!.gramYBV, 1:sizeY, 1:sizeX)
150150
Ac_mul_B!(gramYBV_view, constr!.BY, X)
151151
tmp_view = view(constr!.tmp, 1:sizeY, 1:sizeX)
152-
A_ldiv_B!(tmp_view, gram_chol, gramYBV_view)
153-
A_mul_B!(X, constr!.Y, tmp_view)
152+
A_ldiv_B!(tmp_view, constr!.gram_chol, gramYBV_view)
153+
A_mul_B!(X_temp, constr!.Y, tmp_view)
154+
@inbounds X .= X .- X_temp
154155

155156
nothing
156157
end
@@ -200,9 +201,9 @@ PAP!(BlockGram, PBlocks, n) = Ac_mul_B!(view(BlockGram.PAP, 1:n, 1:n), view(PBlo
200201
XBP!(BlockGram, XBlocks, PBlocks, n) = Ac_mul_B!(view(BlockGram.XAP, :, 1:n), XBlocks.block, view(PBlocks.B_block, :, 1:n))
201202
XBR!(BlockGram, XBlocks, RBlocks, n) = Ac_mul_B!(view(BlockGram.XAR, :, 1:n), XBlocks.block, view(RBlocks.B_block, :, 1:n))
202203
RBP!(BlockGram, RBlocks, PBlocks, n) = Ac_mul_B!(view(BlockGram.RAP, 1:n, 1:n), view(RBlocks.B_block, :, 1:n), view(PBlocks.block, :, 1:n))
203-
XBX!(BlockGram, XBlocks) = Ac_mul_B!(BlockGram.XAX, XBlocks.block, XBlocks.B_block)
204-
RBR!(BlockGram, RBlocks, n) = Ac_mul_B!(view(BlockGram.RAR, 1:n, 1:n), view(RBlocks.block, :, 1:n), view(RBlocks.B_block, :, 1:n))
205-
PBP!(BlockGram, PBlocks, n) = Ac_mul_B!(view(BlockGram.PAP, 1:n, 1:n), view(PBlocks.block, :, 1:n), view(PBlocks.B_block, :, 1:n))
204+
#XBX!(BlockGram, XBlocks) = Ac_mul_B!(BlockGram.XAX, XBlocks.block, XBlocks.B_block)
205+
#RBR!(BlockGram, RBlocks, n) = Ac_mul_B!(view(BlockGram.RAR, 1:n, 1:n), view(RBlocks.block, :, 1:n), view(RBlocks.B_block, :, 1:n))
206+
#PBP!(BlockGram, PBlocks, n) = Ac_mul_B!(view(BlockGram.PAP, 1:n, 1:n), view(PBlocks.block, :, 1:n), view(PBlocks.B_block, :, 1:n))
206207

207208
function I!(G, xr)
208209
@inbounds for j in xr, i in xr
@@ -242,24 +243,24 @@ function (g::BlockGram)(gram, n1::Int, n2::Int, n3::Int, normalized::Bool=true)
242243
if n1 > 0
243244
if normalized
244245
I!(gram, xr)
245-
else
246-
@inbounds gram[xr, xr] .= view(g.XAX, 1:n1, 1:n1)
246+
#else
247+
# @inbounds gram[xr, xr] .= view(g.XAX, 1:n1, 1:n1)
247248
end
248249
end
249250
if n2 > 0
250251
if normalized
251252
I!(gram, rr)
252-
else
253-
@inbounds gram[rr, rr] .= view(g.RAR, 1:n2, 1:n2)
253+
#else
254+
# @inbounds gram[rr, rr] .= view(g.RAR, 1:n2, 1:n2)
254255
end
255256
@inbounds gram[xr, rr] .= view(g.XAR, 1:n1, 1:n2)
256257
@inbounds conj!(transpose!(view(gram, rr, xr), view(g.XAR, 1:n1, 1:n2)))
257258
end
258259
if n3 > 0
259260
if normalized
260261
I!(gram, pr)
261-
else
262-
@inbounds gram[pr, pr] .= view(g.PAP, 1:n3, 1:n3)
262+
#else
263+
# @inbounds gram[pr, pr] .= view(g.PAP, 1:n3, 1:n3)
263264
end
264265
@inbounds gram[rr, pr] .= view(g.RAP, 1:n2, 1:n3)
265266
@inbounds gram[xr, pr] .= view(g.XAP, 1:n1, 1:n3)
@@ -463,10 +464,10 @@ function update_active!(mask, bs::Int, blockPairs...)
463464
return
464465
end
465466

466-
function precond_constr!(block, bs, precond!, constr!)
467+
function precond_constr!(block, temp_block, bs, precond!, constr!)
467468
precond!(view(block, :, 1:bs))
468469
# Constrain the active residual vectors to be B-orthogonal to Y
469-
constr!(view(block, :, 1:bs))
470+
constr!(view(block, :, 1:bs), view(temp_block, :, 1:bs))
470471
return
471472
end
472473
function block_grams_1x1!(iterator)
@@ -525,7 +526,6 @@ function sub_problem!(iterator, sizeX, bs1, bs2)
525526
selectperm!(view(iterator.λperm, 1:subdim), eigf.values, 1:subdim, rev=iterator.largest)
526527
@inbounds iterator.ritz_values[1:sizeX] .= view(eigf.values, view(iterator.λperm, 1:sizeX))
527528
@inbounds iterator.V[1:subdim, 1:sizeX] .= view(eigf.vectors, :, view(iterator.λperm, 1:sizeX))
528-
529529
return
530530
end
531531

@@ -594,7 +594,6 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
594594
sizeX = size(iterator.XBlocks.block, 2)
595595
iteration = iterator.iteration[]
596596
if iteration == 1
597-
iterator.constr!(iterator.XBlocks.block)
598597
ortho_AB_mul_X!(iterator.XBlocks, iterator.ortho!, iterator.A, iterator.B)
599598
# Finds gram matrix X'AX
600599
block_grams_1x1!(iterator)
@@ -608,7 +607,7 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
608607
# Update active R blocks
609608
update_active!(iterator.activeMask, bs, (iterator.activeRBlocks.block, iterator.RBlocks.block))
610609
# Precondition and constrain the active residual vectors
611-
precond_constr!(iterator.activeRBlocks.block, bs, iterator.precond!, iterator.constr!)
610+
precond_constr!(iterator.activeRBlocks.block, iterator.tempXBlocks.block, bs, iterator.precond!, iterator.constr!)
612611
# Orthonormalizes R[:,1:bs] and finds AR[:,1:bs] and BR[:,1:bs]
613612
ortho_AB_mul_X!(iterator.activeRBlocks, iterator.ortho!, iterator.A, iterator.B, bs)
614613
# Find [X R] A [X R] and [X R]' B [X R]
@@ -628,7 +627,7 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
628627
(iterator.activePBlocks.A_block, iterator.PBlocks.A_block),
629628
(iterator.activePBlocks.B_block, iterator.PBlocks.B_block))
630629
# Precondition and constrain the active residual vectors
631-
precond_constr!(iterator.activeRBlocks.block, bs, iterator.precond!, iterator.constr!)
630+
precond_constr!(iterator.activeRBlocks.block, iterator.tempXBlocks.block, bs, iterator.precond!, iterator.constr!)
632631
# Orthonormalizes R[:,1:bs] and finds AR[:,1:bs] and BR[:,1:bs]
633632
ortho_AB_mul_X!(iterator.activeRBlocks, iterator.ortho!, iterator.A, iterator.B, bs)
634633
# Orthonormalizes P and updates AP
@@ -766,16 +765,18 @@ end
766765
function lobpcg!(iterator::LOBPCGIterator; log=false, tol=nothing, maxiter=200, not_zeros=false)
767766
T = eltype(iterator.XBlocks.block)
768767
X = iterator.XBlocks.block
768+
iterator.constr!(iterator.XBlocks.block, iterator.tempXBlocks.block)
769769
if !not_zeros
770770
for j in 1:size(X,2)
771771
if all(x -> x==0, view(X, :, j))
772772
@inbounds X[:,j] .= rand.()
773773
end
774774
end
775+
iterator.constr!(iterator.XBlocks.block, iterator.tempXBlocks.block)
775776
end
776777
n = size(X, 1)
777778
sizeX = size(X, 2)
778-
residualTolerance = (tol isa Void) ? sqrt(eps(real(T))) : tol
779+
residualTolerance = (tol isa Void) ? (eps(real(T)))^(real(T)(4)/10) : tol
779780
iterator.iteration[] = 1
780781
while iterator.iteration[] <= maxiter
781782
state = iterator(residualTolerance, log)

0 commit comments

Comments
 (0)