Skip to content

Commit 85462ce

Browse files
authored
Eliminating Matrix operations in MLMG CG bottom solver if initial vector is zero (#3668)
A matrix multiplication and a few copy operations can be avoided if the input vector is zero. MLMG calls all the the bottom solvers with zeroed `x` vector, and thus the initial residual calculation `b - Ax` is `b`. Furthermore, it also eliminates the memory requirement of storing the initial vector.
1 parent ef38229 commit 85462ce

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ public:
4242
void setMaxIter (int _maxiter) { maxiter = _maxiter; }
4343
[[nodiscard]] int getMaxIter () const { return maxiter; }
4444

45+
46+
/**
47+
* Is the initial guess provided to the solver zero ?
48+
* If so, set this to true.
49+
* The solver will avoid a few operations if this is true.
50+
* Default is false.
51+
*/
52+
void setInitSolnZeroed (bool _sol_zeroed) { initial_vec_zeroed = _sol_zeroed; }
53+
[[nodiscard]] bool getInitSolnZeroed () const { return initial_vec_zeroed; }
54+
4555
void setNGhost(int _nghost) {nghost = IntVect(_nghost);}
4656
[[nodiscard]] int getNGhost() {return nghost[0];}
4757

@@ -62,6 +72,7 @@ private:
6272
int maxiter = 100;
6373
IntVect nghost = IntVect(0);
6474
int iter = -1;
75+
bool initial_vec_zeroed = false;
6576
};
6677

6778
template <typename MF>
@@ -95,21 +106,28 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
95106
p.setVal(RT(0.0)); // Make sure all entries are initialized to avoid errors
96107
r.setVal(RT(0.0));
97108

98-
MF sorig = Lp.make(amrlev, mglev, nghost);
99109
MF rh = Lp.make(amrlev, mglev, nghost);
100110
MF v = Lp.make(amrlev, mglev, nghost);
101111
MF t = Lp.make(amrlev, mglev, nghost);
102112

103-
Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
113+
114+
MF sorig;
115+
116+
if ( initial_vec_zeroed ) {
117+
r.LocalCopy(rhs,0,0,ncomp,nghost);
118+
} else {
119+
sorig = Lp.make(amrlev, mglev, nghost);
120+
121+
Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
122+
123+
sorig.LocalCopy(sol,0,0,ncomp,nghost);
124+
sol.setVal(RT(0.0));
125+
}
104126

105127
// Then normalize
106128
Lp.normalize(amrlev, mglev, r);
107-
108-
sorig.LocalCopy(sol,0,0,ncomp,nghost);
109129
rh.LocalCopy (r ,0,0,ncomp,nghost);
110130

111-
sol.setVal(RT(0.0));
112-
113131
RT rnorm = norm_inf(r);
114132
const RT rnorm0 = rnorm;
115133

@@ -238,12 +256,16 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
238256

239257
if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
240258
{
241-
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
259+
if ( !initial_vec_zeroed ) {
260+
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
261+
}
242262
}
243263
else
244264
{
245265
sol.setVal(RT(0.0));
246-
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
266+
if ( !initial_vec_zeroed ) {
267+
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
268+
}
247269
}
248270

249271
return ret;
@@ -260,15 +282,21 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
260282
MF p = Lp.make(amrlev, mglev, sol.nGrowVect());
261283
p.setVal(RT(0.0));
262284

263-
MF sorig = Lp.make(amrlev, mglev, nghost);
264285
MF r = Lp.make(amrlev, mglev, nghost);
265286
MF q = Lp.make(amrlev, mglev, nghost);
266287

267-
sorig.LocalCopy(sol,0,0,ncomp,nghost);
288+
MF sorig;
289+
290+
if ( initial_vec_zeroed ) {
291+
r.LocalCopy(rhs,0,0,ncomp,nghost);
292+
} else {
293+
sorig = Lp.make(amrlev, mglev, nghost);
268294

269-
Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
295+
Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
270296

271-
sol.setVal(RT(0.0));
297+
sorig.LocalCopy(sol,0,0,ncomp,nghost);
298+
sol.setVal(RT(0.0));
299+
}
272300

273301
RT rnorm = norm_inf(r);
274302
const RT rnorm0 = rnorm;
@@ -364,12 +392,16 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
364392

365393
if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
366394
{
367-
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
395+
if ( !initial_vec_zeroed ) {
396+
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
397+
}
368398
}
369399
else
370400
{
371401
sol.setVal(RT(0.0));
372-
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
402+
if ( !initial_vec_zeroed ) {
403+
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
404+
}
373405
}
374406

375407
return ret;

Src/LinearSolvers/MLMG/AMReX_MLMG.H

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,7 @@ MLMGT<MF>::bottomSolveWithCG (MF& x, const MF& b, typename MLCGSolverT<MF>::Type
15261526
cg_solver.setSolver(type);
15271527
cg_solver.setVerbose(bottom_verbose);
15281528
cg_solver.setMaxIter(bottom_maxiter);
1529+
cg_solver.setInitSolnZeroed(true);
15291530
if (cf_strategy == CFStrategy::ghostnodes) { cg_solver.setNGhost(linop.getNGrow()); }
15301531

15311532
int ret = cg_solver.solve(x, b, bottom_reltol, bottom_abstol);

0 commit comments

Comments
 (0)