Skip to content

Commit 75571e2

Browse files
authored
MLMG: Use free functions instead of MF member functions (#3681)
Note that the use of unqualified functions (e.g., setVal instead of amrex::setVal) is intentional. With ADL, these calls in MLMG could work with user defined data.
1 parent 3407e87 commit 75571e2

File tree

3 files changed

+199
-160
lines changed

3 files changed

+199
-160
lines changed

Src/LinearSolvers/MLMG/AMReX_MLCGSolver.H

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ class MLCGSolverT
1212
{
1313
public:
1414

15-
using FAB = typename MF::fab_type;
16-
using RT = typename MF::value_type;
15+
using FAB = typename MLLinOpT<MF>::FAB;
16+
using RT = typename MLLinOpT<MF>::RT;
1717

1818
enum struct Type { BiCGStab, CG };
1919

@@ -99,12 +99,12 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
9999
{
100100
BL_PROFILE("MLCGSolver::bicgstab");
101101

102-
const int ncomp = sol.nComp();
102+
const int ncomp = nComp(sol);
103103

104-
MF p = Lp.make(amrlev, mglev, sol.nGrowVect());
105-
MF r = Lp.make(amrlev, mglev, sol.nGrowVect());
106-
p.setVal(RT(0.0)); // Make sure all entries are initialized to avoid errors
107-
r.setVal(RT(0.0));
104+
MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
105+
MF r = Lp.make(amrlev, mglev, nGrowVect(sol));
106+
setVal(p, RT(0.0)); // Make sure all entries are initialized to avoid errors
107+
setVal(r, RT(0.0));
108108

109109
MF rh = Lp.make(amrlev, mglev, nghost);
110110
MF v = Lp.make(amrlev, mglev, nghost);
@@ -114,19 +114,19 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
114114
MF sorig;
115115

116116
if ( initial_vec_zeroed ) {
117-
r.LocalCopy(rhs,0,0,ncomp,nghost);
117+
LocalCopy(r,rhs,0,0,ncomp,nghost);
118118
} else {
119119
sorig = Lp.make(amrlev, mglev, nghost);
120120

121121
Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
122122

123-
sorig.LocalCopy(sol,0,0,ncomp,nghost);
124-
sol.setVal(RT(0.0));
123+
LocalCopy(sorig,sol,0,0,ncomp,nghost);
124+
setVal(sol, RT(0.0));
125125
}
126126

127127
// Then normalize
128128
Lp.normalize(amrlev, mglev, r);
129-
rh.LocalCopy (r ,0,0,ncomp,nghost);
129+
LocalCopy(rh, r, 0,0,ncomp,nghost);
130130

131131
RT rnorm = norm_inf(r);
132132
const RT rnorm0 = rnorm;
@@ -159,13 +159,13 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
159159
}
160160
if ( iter == 1 )
161161
{
162-
p.LocalCopy(r,0,0,ncomp,nghost);
162+
LocalCopy(p,r,0,0,ncomp,nghost);
163163
}
164164
else
165165
{
166166
const RT beta = (rho/rho_1)*(alpha/omega);
167-
MF::Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
168-
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
167+
Saxpy(p, -omega, v, 0, 0, ncomp, nghost); // p += -omega*v
168+
Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta*p
169169
}
170170
Lp.apply(amrlev, mglev, v, p, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
171171
Lp.normalize(amrlev, mglev, v);
@@ -179,8 +179,8 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
179179
{
180180
ret = 2; break;
181181
}
182-
MF::Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
183-
MF::Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
182+
Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
183+
Saxpy(r, -alpha, v, 0, 0, ncomp, nghost); // r += -alpha * v
184184

185185
rnorm = norm_inf(r);
186186
rnorm = norm_inf(r);
@@ -216,8 +216,8 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
216216
{
217217
ret = 3; break;
218218
}
219-
MF::Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
220-
MF::Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
219+
Saxpy(sol, omega, r, 0, 0, ncomp, nghost); // sol += omega * r
220+
Saxpy(r, -omega, t, 0, 0, ncomp, nghost); // r += -omega * t
221221

222222
rnorm = norm_inf(r);
223223

@@ -257,14 +257,14 @@ MLCGSolverT<MF>::solve_bicgstab (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
257257
if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
258258
{
259259
if ( !initial_vec_zeroed ) {
260-
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
260+
LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
261261
}
262262
}
263263
else
264264
{
265-
sol.setVal(RT(0.0));
265+
setVal(sol, RT(0.0));
266266
if ( !initial_vec_zeroed ) {
267-
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
267+
LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
268268
}
269269
}
270270

@@ -277,25 +277,25 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
277277
{
278278
BL_PROFILE("MLCGSolver::cg");
279279

280-
const int ncomp = sol.nComp();
280+
const int ncomp = nComp(sol);
281281

282-
MF p = Lp.make(amrlev, mglev, sol.nGrowVect());
283-
p.setVal(RT(0.0));
282+
MF p = Lp.make(amrlev, mglev, nGrowVect(sol));
283+
setVal(p, RT(0.0));
284284

285285
MF r = Lp.make(amrlev, mglev, nghost);
286286
MF q = Lp.make(amrlev, mglev, nghost);
287287

288288
MF sorig;
289289

290290
if ( initial_vec_zeroed ) {
291-
r.LocalCopy(rhs,0,0,ncomp,nghost);
291+
LocalCopy(r,rhs,0,0,ncomp,nghost);
292292
} else {
293293
sorig = Lp.make(amrlev, mglev, nghost);
294294

295295
Lp.correctionResidual(amrlev, mglev, r, sol, rhs, MLLinOpT<MF>::BCMode::Homogeneous);
296296

297-
sorig.LocalCopy(sol,0,0,ncomp,nghost);
298-
sol.setVal(RT(0.0));
297+
LocalCopy(sorig,sol,0,0,ncomp,nghost);
298+
setVal(sol, RT(0.0));
299299
}
300300

301301
RT rnorm = norm_inf(r);
@@ -330,12 +330,12 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
330330
}
331331
if (iter == 1)
332332
{
333-
p.LocalCopy(r,0,0,ncomp,nghost);
333+
LocalCopy(p,r,0,0,ncomp,nghost);
334334
}
335335
else
336336
{
337337
RT beta = rho/rho_1;
338-
MF::Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
338+
Xpay(p, beta, r, 0, 0, ncomp, nghost); // p = r + beta * p
339339
}
340340
Lp.apply(amrlev, mglev, q, p, MLLinOpT<MF>::BCMode::Homogeneous, MLLinOpT<MF>::StateMode::Correction);
341341

@@ -357,8 +357,8 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
357357
<< " rho " << rho
358358
<< " alpha " << alpha << '\n';
359359
}
360-
MF::Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
361-
MF::Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
360+
Saxpy(sol, alpha, p, 0, 0, ncomp, nghost); // sol += alpha * p
361+
Saxpy(r, -alpha, q, 0, 0, ncomp, nghost); // r += -alpha * q
362362
rnorm = norm_inf(r);
363363

364364
if ( verbose > 2 )
@@ -393,14 +393,14 @@ MLCGSolverT<MF>::solve_cg (MF& sol, const MF& rhs, RT eps_rel, RT eps_abs)
393393
if ( ( ret == 0 || ret == 8 ) && (rnorm < rnorm0) )
394394
{
395395
if ( !initial_vec_zeroed ) {
396-
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
396+
LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
397397
}
398398
}
399399
else
400400
{
401-
sol.setVal(RT(0.0));
401+
setVal(sol, RT(0.0));
402402
if ( !initial_vec_zeroed ) {
403-
sol.LocalAdd(sorig, 0, 0, ncomp, nghost);
403+
LocalAdd(sol, sorig, 0, 0, ncomp, nghost);
404404
}
405405
}
406406

@@ -422,8 +422,8 @@ template <typename MF>
422422
auto
423423
MLCGSolverT<MF>::norm_inf (const MF& res, bool local) -> RT
424424
{
425-
int ncomp = res.nComp();
426-
RT result = res.norminf(0,ncomp,IntVect(0),true);
425+
int ncomp = nComp(res);
426+
RT result = norminf(res,0,ncomp,IntVect(0),true);
427427
if (!local) {
428428
BL_PROFILE("MLCGSolver::ParallelAllReduce");
429429
ParallelAllReduce::Max(result, Lp.BottomCommunicator());

Src/LinearSolvers/MLMG/AMReX_MLLinOp.H

Lines changed: 60 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,15 @@ struct LinOpEnumType
8585
enum struct Location { FaceCenter, FaceCentroid, CellCenter, CellCentroid };
8686
};
8787

88+
template <typename T, class Enable = void> struct LinOpData {};
89+
//
90+
template <typename T>
91+
struct LinOpData <T, std::enable_if_t<IsMultiFabLike_v<T> > >
92+
{
93+
using fab_type = typename T::fab_type;
94+
using value_type = typename T::value_type;
95+
};
96+
8897
template <typename T> class MLMGT;
8998
template <typename T> class MLCGSolverT;
9099
template <typename T> class MLPoissonT;
@@ -100,8 +109,8 @@ public:
100109
template <typename T> friend class MLPoissonT;
101110
template <typename T> friend class MLABecLaplacianT;
102111

103-
using FAB = typename MF::fab_type;
104-
using RT = typename MF::value_type;
112+
using FAB = typename LinOpData<MF>::fab_type;
113+
using RT = typename LinOpData<MF>::value_type;
105114

106115
using BCType = LinOpBCType;
107116
using BCMode = LinOpEnumType::BCMode;
@@ -1375,53 +1384,81 @@ template <typename MF>
13751384
void
13761385
MLLinOpT<MF>::make (Vector<Vector<MF> >& mf, IntVect const& ng) const
13771386
{
1378-
mf.clear();
1379-
mf.resize(m_num_amr_levels);
1380-
for (int alev = 0; alev < m_num_amr_levels; ++alev) {
1381-
mf[alev].resize(m_num_mg_levels[alev]);
1382-
for (int mlev = 0; mlev < m_num_mg_levels[alev]; ++mlev) {
1383-
mf[alev][mlev] = make(alev, mlev, ng);
1387+
if constexpr (IsMultiFabLike_v<MF>) {
1388+
mf.clear();
1389+
mf.resize(m_num_amr_levels);
1390+
for (int alev = 0; alev < m_num_amr_levels; ++alev) {
1391+
mf[alev].resize(m_num_mg_levels[alev]);
1392+
for (int mlev = 0; mlev < m_num_mg_levels[alev]; ++mlev) {
1393+
mf[alev][mlev] = make(alev, mlev, ng);
1394+
}
13841395
}
1396+
} else {
1397+
amrex::ignore_unused(mf, ng);
1398+
amrex::Abort("MLLinOpT::make: how did we get here?");
13851399
}
13861400
}
13871401

13881402
template <typename MF>
13891403
MF
13901404
MLLinOpT<MF>::make (int amrlev, int mglev, IntVect const& ng) const
13911405
{
1392-
return MF(amrex::convert(m_grids[amrlev][mglev], m_ixtype),
1393-
m_dmap[amrlev][mglev], getNComp(), ng, MFInfo(),
1394-
*m_factory[amrlev][mglev]);
1406+
if constexpr (IsMultiFabLike_v<MF>) {
1407+
return MF(amrex::convert(m_grids[amrlev][mglev], m_ixtype),
1408+
m_dmap[amrlev][mglev], getNComp(), ng, MFInfo(),
1409+
*m_factory[amrlev][mglev]);
1410+
} else {
1411+
amrex::ignore_unused(amrlev, mglev, ng);
1412+
amrex::Abort("MLLinOpT::make: how did we get here?");
1413+
return {};
1414+
}
13951415
}
13961416

13971417
template <typename MF>
13981418
MF
13991419
MLLinOpT<MF>::makeAlias (MF const& mf) const
14001420
{
1401-
return MF(mf, amrex::make_alias, 0, mf.nComp());
1421+
if constexpr (IsMultiFabLike_v<MF>) {
1422+
return MF(mf, amrex::make_alias, 0, mf.nComp());
1423+
} else {
1424+
amrex::ignore_unused(mf);
1425+
amrex::Abort("MLLinOpT::makeAlias: how did we get here?");
1426+
return {};
1427+
}
14021428
}
14031429

14041430
template <typename MF>
14051431
MF
14061432
MLLinOpT<MF>::makeCoarseMG (int amrlev, int mglev, IntVect const& ng) const
14071433
{
1408-
BoxArray cba = m_grids[amrlev][mglev];
1409-
IntVect ratio = (amrlev > 0) ? IntVect(2) : mg_coarsen_ratio_vec[mglev];
1410-
cba.coarsen(ratio);
1411-
cba.convert(m_ixtype);
1412-
return MF(cba, m_dmap[amrlev][mglev], getNComp(), ng);
1413-
1434+
if constexpr (IsMultiFabLike_v<MF>) {
1435+
BoxArray cba = m_grids[amrlev][mglev];
1436+
IntVect ratio = (amrlev > 0) ? IntVect(2) : mg_coarsen_ratio_vec[mglev];
1437+
cba.coarsen(ratio);
1438+
cba.convert(m_ixtype);
1439+
return MF(cba, m_dmap[amrlev][mglev], getNComp(), ng);
1440+
} else {
1441+
amrex::ignore_unused(amrlev, mglev, ng);
1442+
amrex::Abort("MLLinOpT::makeCoarseMG: how did we get here?");
1443+
return {};
1444+
}
14141445
}
14151446

14161447
template <typename MF>
14171448
MF
14181449
MLLinOpT<MF>::makeCoarseAmr (int famrlev, IntVect const& ng) const
14191450
{
1420-
BoxArray cba = m_grids[famrlev][0];
1421-
IntVect ratio(AMRRefRatio(famrlev-1));
1422-
cba.coarsen(ratio);
1423-
cba.convert(m_ixtype);
1424-
return MF(cba, m_dmap[famrlev][0], getNComp(), ng);
1451+
if constexpr (IsMultiFabLike_v<MF>) {
1452+
BoxArray cba = m_grids[famrlev][0];
1453+
IntVect ratio(AMRRefRatio(famrlev-1));
1454+
cba.coarsen(ratio);
1455+
cba.convert(m_ixtype);
1456+
return MF(cba, m_dmap[famrlev][0], getNComp(), ng);
1457+
} else {
1458+
amrex::ignore_unused(famrlev, ng);
1459+
amrex::Abort("MLLinOpT::makeCoarseAmr: how did we get here?");
1460+
return {};
1461+
}
14251462
}
14261463

14271464
template <typename MF>

0 commit comments

Comments
 (0)