@@ -112,22 +112,22 @@ function B_mul_X!(b::Blocks{false}, B, n = 0)
112
112
return
113
113
end
114
114
115
- struct Constraint{T, TA <: AbstractArray {T} , TC}
116
- Y:: TA
117
- BY:: TA
115
+ struct Constraint{T, TVorM <: Union{AbstractVector{T}, AbstractMatrix{T}} , TM <: AbstractMatrix {T} , TC}
116
+ Y:: TVorM
117
+ BY:: TVorM
118
118
gram_chol:: TC
119
- gramYBV:: TA # to be used in view
120
- tmp:: TA # to be used in view
119
+ gramYBV:: TM # to be used in view
120
+ tmp:: TM # to be used in view
121
121
end
122
122
function Constraint (:: Void , B, X)
123
- return Constraint {Void, Matrix{Void}, Void} (Matrix {Void} (0 ,0 ), Matrix {Void} (0 ,0 ), nothing , Matrix {Void} (0 ,0 ), Matrix {Void} (0 ,0 ))
123
+ return Constraint {Void, Matrix{Void}, Matrix{Void}, Void} (Matrix {Void} (0 ,0 ), Matrix {Void} (0 ,0 ), nothing , Matrix {Void} (0 ,0 ), Matrix {Void} (0 ,0 ))
124
124
end
125
125
function Constraint (Y, B, X)
126
126
T = eltype (X)
127
127
if B isa Void
128
- B = Y
128
+ BY = Y
129
129
else
130
- BY = similar (B )
130
+ BY = similar (Y )
131
131
A_mul_B! (BY, B, Y)
132
132
end
133
133
gramYBY = Ac_mul_B (Y, BY)
@@ -136,21 +136,22 @@ function Constraint(Y, B, X)
136
136
gramYBV = zeros (T, size (Y, 2 ), size (X, 2 ))
137
137
tmp = similar (gramYBV)
138
138
139
- return Constraint (Y, BY, gramYBY_chol, gramYBV, tmp)
139
+ return Constraint {eltype(Y), typeof(Y), typeof(gramYBV), typeof(gramYBY_chol)} (Y, BY, gramYBY_chol, gramYBV, tmp)
140
140
end
141
141
142
- function (constr!:: Constraint{Void} )(X)
142
+ function (constr!:: Constraint{Void} )(X, X_temp )
143
143
nothing
144
144
end
145
145
146
- function (constr!:: Constraint )(X)
146
+ function (constr!:: Constraint )(X, X_temp )
147
147
sizeX = size (X, 2 )
148
148
sizeY = size (constr!. Y, 2 )
149
149
gramYBV_view = view (constr!. gramYBV, 1 : sizeY, 1 : sizeX)
150
150
Ac_mul_B! (gramYBV_view, constr!. BY, X)
151
151
tmp_view = view (constr!. tmp, 1 : sizeY, 1 : sizeX)
152
- A_ldiv_B! (tmp_view, gram_chol, gramYBV_view)
153
- A_mul_B! (X, constr!. Y, tmp_view)
152
+ A_ldiv_B! (tmp_view, constr!. gram_chol, gramYBV_view)
153
+ A_mul_B! (X_temp, constr!. Y, tmp_view)
154
+ @inbounds X .= X .- X_temp
154
155
155
156
nothing
156
157
end
@@ -200,9 +201,9 @@ PAP!(BlockGram, PBlocks, n) = Ac_mul_B!(view(BlockGram.PAP, 1:n, 1:n), view(PBlo
200
201
XBP! (BlockGram, XBlocks, PBlocks, n) = Ac_mul_B! (view (BlockGram. XAP, :, 1 : n), XBlocks. block, view (PBlocks. B_block, :, 1 : n))
201
202
XBR! (BlockGram, XBlocks, RBlocks, n) = Ac_mul_B! (view (BlockGram. XAR, :, 1 : n), XBlocks. block, view (RBlocks. B_block, :, 1 : n))
202
203
RBP! (BlockGram, RBlocks, PBlocks, n) = Ac_mul_B! (view (BlockGram. RAP, 1 : n, 1 : n), view (RBlocks. B_block, :, 1 : n), view (PBlocks. block, :, 1 : n))
203
- XBX! (BlockGram, XBlocks) = Ac_mul_B! (BlockGram. XAX, XBlocks. block, XBlocks. B_block)
204
- RBR! (BlockGram, RBlocks, n) = Ac_mul_B! (view (BlockGram. RAR, 1 : n, 1 : n), view (RBlocks. block, :, 1 : n), view (RBlocks. B_block, :, 1 : n))
205
- PBP! (BlockGram, PBlocks, n) = Ac_mul_B! (view (BlockGram. PAP, 1 : n, 1 : n), view (PBlocks. block, :, 1 : n), view (PBlocks. B_block, :, 1 : n))
204
+ # XBX!(BlockGram, XBlocks) = Ac_mul_B!(BlockGram.XAX, XBlocks.block, XBlocks.B_block)
205
+ # RBR!(BlockGram, RBlocks, n) = Ac_mul_B!(view(BlockGram.RAR, 1:n, 1:n), view(RBlocks.block, :, 1:n), view(RBlocks.B_block, :, 1:n))
206
+ # PBP!(BlockGram, PBlocks, n) = Ac_mul_B!(view(BlockGram.PAP, 1:n, 1:n), view(PBlocks.block, :, 1:n), view(PBlocks.B_block, :, 1:n))
206
207
207
208
function I! (G, xr)
208
209
@inbounds for j in xr, i in xr
@@ -242,24 +243,24 @@ function (g::BlockGram)(gram, n1::Int, n2::Int, n3::Int, normalized::Bool=true)
242
243
if n1 > 0
243
244
if normalized
244
245
I! (gram, xr)
245
- else
246
- @inbounds gram[xr, xr] .= view (g. XAX, 1 : n1, 1 : n1)
246
+ # else
247
+ # @inbounds gram[xr, xr] .= view(g.XAX, 1:n1, 1:n1)
247
248
end
248
249
end
249
250
if n2 > 0
250
251
if normalized
251
252
I! (gram, rr)
252
- else
253
- @inbounds gram[rr, rr] .= view (g. RAR, 1 : n2, 1 : n2)
253
+ # else
254
+ # @inbounds gram[rr, rr] .= view(g.RAR, 1:n2, 1:n2)
254
255
end
255
256
@inbounds gram[xr, rr] .= view (g. XAR, 1 : n1, 1 : n2)
256
257
@inbounds conj! (transpose! (view (gram, rr, xr), view (g. XAR, 1 : n1, 1 : n2)))
257
258
end
258
259
if n3 > 0
259
260
if normalized
260
261
I! (gram, pr)
261
- else
262
- @inbounds gram[pr, pr] .= view (g. PAP, 1 : n3, 1 : n3)
262
+ # else
263
+ # @inbounds gram[pr, pr] .= view(g.PAP, 1:n3, 1:n3)
263
264
end
264
265
@inbounds gram[rr, pr] .= view (g. RAP, 1 : n2, 1 : n3)
265
266
@inbounds gram[xr, pr] .= view (g. XAP, 1 : n1, 1 : n3)
@@ -463,10 +464,10 @@ function update_active!(mask, bs::Int, blockPairs...)
463
464
return
464
465
end
465
466
466
- function precond_constr! (block, bs, precond!, constr!)
467
+ function precond_constr! (block, temp_block, bs, precond!, constr!)
467
468
precond! (view (block, :, 1 : bs))
468
469
# Constrain the active residual vectors to be B-orthogonal to Y
469
- constr! (view (block, :, 1 : bs))
470
+ constr! (view (block, :, 1 : bs), view (temp_block, :, 1 : bs) )
470
471
return
471
472
end
472
473
function block_grams_1x1! (iterator)
@@ -525,7 +526,6 @@ function sub_problem!(iterator, sizeX, bs1, bs2)
525
526
selectperm! (view (iterator. λperm, 1 : subdim), eigf. values, 1 : subdim, rev= iterator. largest)
526
527
@inbounds iterator. ritz_values[1 : sizeX] .= view (eigf. values, view (iterator. λperm, 1 : sizeX))
527
528
@inbounds iterator. V[1 : subdim, 1 : sizeX] .= view (eigf. vectors, :, view (iterator. λperm, 1 : sizeX))
528
-
529
529
return
530
530
end
531
531
@@ -594,7 +594,6 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
594
594
sizeX = size (iterator. XBlocks. block, 2 )
595
595
iteration = iterator. iteration[]
596
596
if iteration == 1
597
- iterator. constr! (iterator. XBlocks. block)
598
597
ortho_AB_mul_X! (iterator. XBlocks, iterator. ortho!, iterator. A, iterator. B)
599
598
# Finds gram matrix X'AX
600
599
block_grams_1x1! (iterator)
@@ -608,7 +607,7 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
608
607
# Update active R blocks
609
608
update_active! (iterator. activeMask, bs, (iterator. activeRBlocks. block, iterator. RBlocks. block))
610
609
# Precondition and constrain the active residual vectors
611
- precond_constr! (iterator. activeRBlocks. block, bs, iterator. precond!, iterator. constr!)
610
+ precond_constr! (iterator. activeRBlocks. block, iterator . tempXBlocks . block, bs, iterator. precond!, iterator. constr!)
612
611
# Orthonormalizes R[:,1:bs] and finds AR[:,1:bs] and BR[:,1:bs]
613
612
ortho_AB_mul_X! (iterator. activeRBlocks, iterator. ortho!, iterator. A, iterator. B, bs)
614
613
# Find [X R] A [X R] and [X R]' B [X R]
@@ -628,7 +627,7 @@ function (iterator::LOBPCGIterator{Generalized})(residualTolerance, log) where {
628
627
(iterator. activePBlocks. A_block, iterator. PBlocks. A_block),
629
628
(iterator. activePBlocks. B_block, iterator. PBlocks. B_block))
630
629
# Precondition and constrain the active residual vectors
631
- precond_constr! (iterator. activeRBlocks. block, bs, iterator. precond!, iterator. constr!)
630
+ precond_constr! (iterator. activeRBlocks. block, iterator . tempXBlocks . block, bs, iterator. precond!, iterator. constr!)
632
631
# Orthonormalizes R[:,1:bs] and finds AR[:,1:bs] and BR[:,1:bs]
633
632
ortho_AB_mul_X! (iterator. activeRBlocks, iterator. ortho!, iterator. A, iterator. B, bs)
634
633
# Orthonormalizes P and updates AP
@@ -766,16 +765,18 @@ end
766
765
function lobpcg! (iterator:: LOBPCGIterator ; log= false , tol= nothing , maxiter= 200 , not_zeros= false )
767
766
T = eltype (iterator. XBlocks. block)
768
767
X = iterator. XBlocks. block
768
+ iterator. constr! (iterator. XBlocks. block, iterator. tempXBlocks. block)
769
769
if ! not_zeros
770
770
for j in 1 : size (X,2 )
771
771
if all (x -> x== 0 , view (X, :, j))
772
772
@inbounds X[:,j] .= rand .()
773
773
end
774
774
end
775
+ iterator. constr! (iterator. XBlocks. block, iterator. tempXBlocks. block)
775
776
end
776
777
n = size (X, 1 )
777
778
sizeX = size (X, 2 )
778
- residualTolerance = (tol isa Void) ? sqrt (eps (real (T))) : tol
779
+ residualTolerance = (tol isa Void) ? (eps (real (T))) ^ ( real (T)( 4 ) / 10 ) : tol
779
780
iterator. iteration[] = 1
780
781
while iterator. iteration[] <= maxiter
781
782
state = iterator (residualTolerance, log)
0 commit comments