@@ -298,11 +298,21 @@ end
298298end
299299
300300
301- function final_retval (X, AX, resid_history, niter, n_matvec)
302- λ = real (diag (X' * AX))
303- residuals = AX .- X* Diagonal (λ)
304- (; λ, X,
305- residual_norms= [norm (residuals[:, i]) for i = 1 : size (residuals, 2 )],
# Assemble the named tuple handed back to the caller of the solver.
#
# Arguments (as used below):
#   X, AX, BX      -- matrices indexed column-by-column; column n of each is combined
#                     into one quotient λ[n]. (Presumably AX = A*X and BX = B*X, as
#                     computed by the caller -- confirm at the call sites.)
#   resid_history  -- 2-d array; its rows are permuted together with the columns of X,
#                     and only its first niter+1 columns are returned.
#   niter          -- number of history columns to keep, minus one.
#   n_matvec       -- passed through unchanged into the result.
#
# Returns a named tuple (; λ, X, AX, BX, residual_norms, residual_history, n_matvec),
# with all per-vector data ordered so that λ is sorted ascending.
301+ function final_retval (X, AX, BX, resid_history, niter, n_matvec)
# Generalized Rayleigh quotient of every column: (x' A x) / (x' B x).
302+ λ = @views [(X[:, n]' * AX[:, n]) / (X[:, n]' BX[:, n]) for n= 1 : size (X, 2 )]
# Convert the freshly built vector to the same type as a column of X, then keep only
# the real part.
303+ λ = real (oftype (X[:, 1 ], λ)) # Offload to GPU if needed
# λ' is a row, so broadcasting scales column n of BX by λ[n]: residuals[:, n] is
# AX[:, n] - λ[n] * BX[:, n].
304+ residuals = AX .- BX .* λ'
# Only pay for the permutation (and the copies made by indexing) when λ is not
# already in ascending order.
305+ if ! issorted (λ)
306+ p = sortperm (λ)
307+ λ = λ[p]
# Apply the same permutation to every per-vector quantity so they stay aligned with λ.
308+ residuals = residuals[:, p]
309+ X = X[:, p]
310+ AX = AX[:, p]
311+ BX = BX[:, p]
# The history is permuted along its first dimension (one row per vector).
312+ resid_history = resid_history[p, :]
313+ end
314+ (; λ, X, AX, BX,
# One norm per column of the residual matrix.
315+ residual_norms= norm .(eachcol (residuals)),
# Keep history columns 1 through niter+1 only.
306316 residual_history= resid_history[:, 1 : niter+ 1 ], n_matvec)
307317end
308318
@@ -358,10 +368,11 @@ end
358368 nlocked = 0
359369 niter = 0 # the first iteration is fake
360370 λs = @views [(X[:, n]' * AX[:, n]) / (X[:, n]' BX[:, n]) for n= 1 : M]
361- λs = oftype (X[:, 1 ], λs) # Offload to GPU if needed
371+ λs = real ( oftype (X[:, 1 ], λs) ) # Offload to GPU if needed
362372 new_X = X
363373 new_AX = AX
364374 new_BX = BX
375+ # The full_ arrays contain all the vectors, the others only get the active ones
365376 full_X = X
366377 full_AX = AX
367378 full_BX = BX
435446 if nlocked >= n_conv_check # Converged!
436447 X .= new_X # Update the part of X which is still active
437448 AX .= new_AX
438- return final_retval (full_X, full_AX, resid_history, niter, n_matvec)
449+ return final_retval (full_X, full_AX, full_BX, resid_history, niter, n_matvec)
439450 end
440451 newly_locked = nlocked - prev_nlocked
441452 active = newly_locked+ 1 : size (X,2 ) # newly active vectors
524535 niter = niter + 1
525536 end
526537
527- final_retval (full_X, full_AX, resid_history, maxiter, n_matvec)
538+ final_retval (full_X, full_AX, full_BX, resid_history, maxiter, n_matvec)
528539end
0 commit comments