Skip to content

Commit bbe3bb3

Browse files
authored
Support Dirichlet BC in z-direction in hybrid FFT Poisson solver (#4503)
The change was provided by JBB.
1 parent 756679f commit bbe3bb3

File tree

3 files changed

+49
-44
lines changed

3 files changed

+49
-44
lines changed

Docs/sphinx_documentation/source/FFT.rst

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,10 @@ boundaries.
188188
FFT::PoissonOpenBC openbc_solver(geom, soln.ixType(), IntVect(ng));
189189
openbc_solver.solve(soln, rhs);
190190

191-
:cpp:`FFT::PoissonHybrid` is a 3D only solver that supports periodic
192-
boundaries in the first two dimensions and Neumann boundary in the last
193-
dimension. The last dimension is solved with a tridiagonal solver that can
194-
support non-uniform cell size in the z-direction. For most applications,
195-
:cpp:`FFT::Poisson` should be used.
191+
:cpp:`FFT::PoissonHybrid` is a 3D only solver that supports Dirichlet and
192+
Neumann boundary in the last dimension. The last dimension is solved with a
193+
tridiagonal solver that can support non-uniform cell size in the
194+
z-direction. For most applications, :cpp:`FFT::Poisson` should be used.
196195

197196
Similar to :cpp:`FFT::R2C`, the Poisson solvers should be cached for reuse,
198197
and one might need to use :cpp:`std::unique_ptr<FFT::Poisson<MultiFab>>`

Src/FFT/AMReX_FFT_Poisson.H

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ private:
9898

9999
/**
100100
* \brief 3D Poisson solver for periodic, Dirichlet & Neumann boundaries in
101-
* the first two dimensions, and Neumann in the last dimension. The last
102-
* dimension could have non-uniform mesh.
101+
* the first two dimensions, and Dirichlet & Neumann in the last
102+
* dimension. The last dimension could have non-uniform mesh.
103103
*/
104104
template <typename MF = MultiFab>
105105
class PoissonHybrid
@@ -126,6 +126,8 @@ public:
126126
bc[idim].second != Boundary::periodic));
127127
}
128128
}
129+
AMREX_ALWAYS_ASSERT(bc[2].first != Boundary::periodic &&
130+
bc[2].second != Boundary::periodic);
129131
Info info{};
130132
info.setTwoDMode(true);
131133
if (periodic_xy) {
@@ -138,24 +140,6 @@ public:
138140
build_spmf();
139141
}
140142

141-
template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
142-
explicit PoissonHybrid (Geometry const& geom)
143-
: m_geom(geom),
144-
m_bc{AMREX_D_DECL(std::make_pair(Boundary::periodic,Boundary::periodic),
145-
std::make_pair(Boundary::periodic,Boundary::periodic),
146-
std::make_pair(Boundary::even,Boundary::even))},
147-
m_r2c(std::make_unique<R2C<typename MF::value_type>>
148-
(geom.Domain(), Info().setTwoDMode(true)))
149-
{
150-
#if (AMREX_SPACEDIM == 3)
151-
AMREX_ALWAYS_ASSERT(geom.isPeriodic(0) && geom.isPeriodic(1));
152-
#else
153-
amrex::Abort("FFT::PoissonHybrid: 1D & 2D todo");
154-
return;
155-
#endif
156-
build_spmf();
157-
}
158-
159143
/*
160144
* \brief Solve del dot grad soln = rhs
161145
*
@@ -329,7 +313,8 @@ namespace fft_poisson_detail {
329313
[[nodiscard]] AMREX_GPU_DEVICE AMREX_FORCE_INLINE
330314
T operator() (int, int, int k) const
331315
{
332-
return T(2.0) /(m_dz[k]*(m_dz[k]+m_dz[k-1]));
316+
return (k > 0) ? T(2.0) / (m_dz[k]*(m_dz[k]+m_dz[k-1]))
317+
: T(1.0) / (m_dz[k]* m_dz[k]);
333318
}
334319
T const* m_dz;
335320
};
@@ -339,9 +324,11 @@ namespace fft_poisson_detail {
339324
[[nodiscard]] AMREX_GPU_DEVICE AMREX_FORCE_INLINE
340325
T operator() (int, int, int k) const
341326
{
342-
return T(2.0) /(m_dz[k]*(m_dz[k]+m_dz[k+1]));
327+
return (k < m_size-1) ? T(2.0) / (m_dz[k]*(m_dz[k]+m_dz[k+1]))
328+
: T(1.0) / (m_dz[k]* m_dz[k]);
343329
}
344330
T const* m_dz;
331+
int m_size;
345332
};
346333
}
347334

@@ -421,7 +408,7 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Gpu::DeviceVector<T> con
421408
auto const* pdz = dz.dataPtr();
422409
solve(soln, rhs,
423410
fft_poisson_detail::TriA<T>{pdz},
424-
fft_poisson_detail::TriC<T>{pdz});
411+
fft_poisson_detail::TriC<T>{pdz,int(dz.size())});
425412
}
426413

427414
template <typename MF>
@@ -438,7 +425,7 @@ void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs, Vector<T> const& dz)
438425
#endif
439426
solve(soln, rhs,
440427
fft_poisson_detail::TriA<T>{pdz},
441-
fft_poisson_detail::TriC<T>{pdz});
428+
fft_poisson_detail::TriC<T>{pdz,int(dz.size())});
442429
}
443430

444431
template <typename MF>
@@ -517,7 +504,10 @@ void PoissonHybrid<MF>::solve_z (FA& spmf, TRIA const& tria, TRIC const& tric)
517504
}
518505
}
519506

520-
bool has_dirichlet = (offset[0] != T(0)) || (offset[1] != T(0));
507+
bool zlo_neumann = m_bc[2].first == Boundary::even;
508+
bool zhi_neumann = m_bc[2].second == Boundary::even;
509+
bool is_singular = (offset[0] == T(0)) && (offset[1] == T(0))
510+
&& zlo_neumann && zhi_neumann;
521511

522512
auto nz = m_geom.Domain().length(2);
523513

@@ -545,18 +535,26 @@ void PoissonHybrid<MF>::solve_z (FA& spmf, TRIA const& tria, TRIC const& tric)
545535
T k2 = dxfac * (std::cos(a)-T(1))
546536
+ dyfac * (std::cos(b)-T(1));
547537

548-
// Tridiagonal solve with homogeneous Neumann
538+
// Tridiagonal solve
549539
for(int k=0; k < nz; k++) {
550540
if(k==0) {
551541
ald(i,j,k) = T(0.);
552542
cud(i,j,k) = tric(i,j,k);
553-
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
543+
if (zlo_neumann) {
544+
bd(i,j,k) = k2 - cud(i,j,k);
545+
} else {
546+
bd(i,j,k) = k2 - cud(i,j,k) - T(2.0)*tria(i,j,k);
547+
}
554548
} else if (k == nz-1) {
555549
ald(i,j,k) = tria(i,j,k);
556550
cud(i,j,k) = T(0.);
557-
bd(i,j,k) = k2 -ald(i,j,k)-cud(i,j,k);
558-
if (i == 0 && j == 0 && !has_dirichlet) {
559-
bd(i,j,k) *= T(2.0);
551+
if (zhi_neumann) {
552+
bd(i,j,k) = k2 - ald(i,j,k);
553+
if (i == 0 && j == 0 && is_singular) {
554+
bd(i,j,k) *= T(2.0);
555+
}
556+
} else {
557+
bd(i,j,k) = k2 - ald(i,j,k) - T(2.0)*tric(i,j,k);
560558
}
561559
} else {
562560
ald(i,j,k) = tria(i,j,k);
@@ -600,18 +598,26 @@ void PoissonHybrid<MF>::solve_z (FA& spmf, TRIA const& tria, TRIC const& tric)
600598
T k2 = dxfac * (std::cos(a)-T(1))
601599
+ dyfac * (std::cos(b)-T(1));
602600

603-
// Tridiagonal solve with homogeneous Neumann
601+
// Tridiagonal solve
604602
for(int k=0; k < nz; k++) {
605603
if(k==0) {
606604
ald[k] = T(0.);
607605
cud[k] = tric(i,j,k);
608-
bd[k] = k2 -ald[k]-cud[k];
606+
if (zlo_neumann) {
607+
bd[k] = k2 - cud[k];
608+
} else {
609+
bd[k] = k2 - cud[k] - T(2.0)*tria(i,j,k);
610+
}
609611
} else if (k == nz-1) {
610612
ald[k] = tria(i,j,k);
611613
cud[k] = T(0.);
612-
bd[k] = k2 -ald[k]-cud[k];
613-
if (i == 0 && j == 0 && !has_dirichlet) {
614-
bd[k] *= T(2.0);
614+
if (zhi_neumann) {
615+
bd[k] = k2 - ald[k];
616+
if (i == 0 && j == 0 && is_singular) {
617+
bd[k] *= T(2.0);
618+
}
619+
} else {
620+
bd[k] = k2 - ald[k] - T(2.0)*tric(i,j,k);
615621
}
616622
} else {
617623
ald[k] = tria(i,j,k);

Tests/FFT/Poisson/main.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void make_rhs (MultiFab& rhs, Geometry const& geom,
4242
fft_bc[idim].second == FFT::Boundary::odd) {
4343
r *= std::sin(x*1.5_rt*fac[idim]);
4444
} else if (fft_bc[idim].first == FFT::Boundary::odd &&
45-
fft_bc[idim].second == FFT::Boundary::even) {
45+
fft_bc[idim].second == FFT::Boundary::even) {
4646
r *= std::sin(x*0.75_rt*fac[idim]);
4747
} else if (fft_bc[idim].first == FFT::Boundary::even &&
4848
fft_bc[idim].second == FFT::Boundary::odd) {
@@ -209,12 +209,12 @@ int main (int argc, char* argv[])
209209
amrex::Print() << " Testing PoissonHybrid\n";
210210

211211
icase = 0;
212+
for (int zcase = 1; zcase < ncasesz; ++zcase) { // skip periodic z-direction
212213
for (int ycase = 0; ycase < ncasesy; ++ycase) {
213214
for (int xcase = 0; xcase < ncases ; ++xcase) {
214215
++icase;
215216
Array<std::pair<FFT::Boundary,FFT::Boundary>,AMREX_SPACEDIM>
216-
fft_bc{bcs[xcase], bcs[ycase],
217-
std::make_pair(FFT::Boundary::even,FFT::Boundary::even)};
217+
fft_bc{bcs[xcase], bcs[ycase], bcs[zcase]};
218218
amrex::Print() << " (" << icase << ") Testing (";
219219
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
220220
amrex::Print() << "(" << getEnumNameString(fft_bc[idim].first)
@@ -244,7 +244,7 @@ int main (int argc, char* argv[])
244244
auto eps = 1.e-11;
245245
#endif
246246
AMREX_ALWAYS_ASSERT(rnorm < eps*bnorm);
247-
}}
247+
}}}
248248
#endif
249249
}
250250

0 commit comments

Comments
 (0)