Skip to content

Commit e2e39bc

Browse files
Sort eigenvalues after LOBPCG (#964)
1 parent de23ab3 commit e2e39bc

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

src/eigen/lobpcg_hyper_impl.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -298,11 +298,21 @@ end
298298
end
299299

300300

301-
function final_retval(X, AX, resid_history, niter, n_matvec)
302-
λ = real(diag(X' * AX))
303-
residuals = AX .- X*Diagonal(λ)
304-
(; λ, X,
305-
residual_norms=[norm(residuals[:, i]) for i = 1:size(residuals, 2)],
301+
function final_retval(X, AX, BX, resid_history, niter, n_matvec)
302+
λ = @views [(X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n]) for n=1:size(X, 2)]
303+
λ = real(oftype(X[:, 1], λ)) # Offload to GPU if needed
304+
residuals = AX .- BX .* λ'
305+
if !issorted(λ)
306+
p = sortperm(λ)
307+
λ = λ[p]
308+
residuals = residuals[:, p]
309+
X = X[:, p]
310+
AX = AX[:, p]
311+
BX = BX[:, p]
312+
resid_history = resid_history[p, :]
313+
end
314+
(; λ, X, AX, BX,
315+
residual_norms=norm.(eachcol(residuals)),
306316
residual_history=resid_history[:, 1:niter+1], n_matvec)
307317
end
308318

@@ -358,10 +368,11 @@ end
358368
nlocked = 0
359369
niter = 0 # the first iteration is fake
360370
λs = @views [(X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n]) for n=1:M]
361-
λs = oftype(X[:, 1], λs) # Offload to GPU if needed
371+
λs = real(oftype(X[:, 1], λs)) # Offload to GPU if needed
362372
new_X = X
363373
new_AX = AX
364374
new_BX = BX
375+
# The full_ arrays contain all the vectors, the others only get the active ones
365376
full_X = X
366377
full_AX = AX
367378
full_BX = BX
@@ -435,7 +446,7 @@ end
435446
if nlocked >= n_conv_check # Converged!
436447
X .= new_X # Update the part of X which is still active
437448
AX .= new_AX
438-
return final_retval(full_X, full_AX, resid_history, niter, n_matvec)
449+
return final_retval(full_X, full_AX, full_BX, resid_history, niter, n_matvec)
439450
end
440451
newly_locked = nlocked - prev_nlocked
441452
active = newly_locked+1:size(X,2) # newly active vectors
@@ -524,5 +535,5 @@ end
524535
niter = niter + 1
525536
end
526537

527-
final_retval(full_X, full_AX, resid_history, maxiter, n_matvec)
538+
final_retval(full_X, full_AX, full_BX, resid_history, maxiter, n_matvec)
528539
end

0 commit comments

Comments
 (0)