Skip to content

Commit 65adc9d

Browse files
committed
Change over to new AMReX FFT
1 parent e5cc076 commit 65adc9d

File tree

2 files changed

+37
-182
lines changed

2 files changed

+37
-182
lines changed

source/gravitational_waves.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,13 @@ void GravitationalWaves::ComputeSpectrum(
6161
double L = sim->L * static_cast<double>(zero_padding);
6262

6363
const LevelData &ld = sim->grid_new[lev];
64-
// const amrex::BoxArray &ba = ld.boxArray();
65-
// const amrex::DistributionMapping &dm = ld.DistributionMap();
66-
6764
amrex::MultiFab du_real[6];
6865
amrex::MultiFab du_imag[6];
6966
const int mat[3][3] = {{0, 1, 2}, {1, 3, 4}, {2, 4, 5}};
7067
int comps[6];
7168
modifier->SelectComponents(comps);
7269

7370
for (int i = 0; i < 6; ++i) {
74-
// du_real[i].define(ba, dm, 1, 0);
75-
// du_imag[i].define(ba, dm, 1, 0);
7671
utils::Fft(ld, comps[i] + idx_offset, du_real[i], du_imag[i],
7772
sim->geom[lev], false, zero_padding);
7873
}

source/utils/fft.h

Lines changed: 37 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
#include <AMReX_BCUtil.H>
55
#include <AMReX_FillPatchUtil.H>
66
#include <AMReX_PhysBCFunct.H>
7-
#include <AlignedAllocator.h>
8-
#include <Dfft.H>
9-
#include <Distribution.H>
7+
// #include <AlignedAllocator.h>
8+
// #include <Dfft.H>
9+
// #include <Distribution.H>
10+
#include <AMReX_FFT.H>
1011
#include <iterator>
1112

1213
#include "hdf5_utils.h"
@@ -144,27 +145,6 @@ static void Fft(const amrex::MultiFab &field, const int comp,
144145
const amrex::Box original_V = original_ba.minimalBox();
145146
const int N = original_V.length(0);
146147
const int N_padded = N * zero_padding;
147-
/*
148-
const amrex::Box padded_V(
149-
amrex::IntVect(0, 0, 0),
150-
amrex::IntVect(N_padded - 1, N_padded - 1, N_padded - 1));
151-
const amrex::BoxArray padded_V_ba(padded_V);
152-
amrex::BoxList extra_bl = padded_V_ba.complementIn(original_V);
153-
amrex::BoxArray extra_ba(std::move(extra_bl));
154-
ChopGrids(extra_ba, amrex::ParallelDescriptor::NProcs());
155-
amrex::BoxList tmp_padded_bl = original_ba.boxList();
156-
tmp_padded_bl.join(extra_ba.boxList());
157-
amrex::BoxArray tmp_padded_ba(tmp_padded_bl);
158-
159-
// Get new crude DistributionMap
160-
amrex::DistributionMapping extra_dm(extra_ba,
161-
amrex::ParallelDescriptor::NProcs());
162-
amrex::Vector<int> extra_pmap = extra_dm.ProcessorMap();
163-
amrex::Vector<int> tmp_padded_pmap = original_pmap;
164-
std::move(extra_pmap.begin(), extra_pmap.end(),
165-
std::back_inserter(tmp_padded_pmap));
166-
amrex::DistributionMapping tmp_padded_dm(tmp_padded_pmap);
167-
*/
168148

169149
amrex::Vector<int> tmp_padded_pmap;
170150
amrex::BoxList tmp_padded_bl;
@@ -189,24 +169,11 @@ static void Fft(const amrex::MultiFab &field, const int comp,
189169
padded_geom.refine(
190170
amrex::IntVect(zero_padding, zero_padding, zero_padding));
191171

192-
// amrex::Print() << "orignal ba\n" << original_ba << std::endl;
193-
// amrex::Print() << "original dm\n" << field.DistributionMap() <<
194-
// std::endl; amrex::Print() << "original geom\n" << geom << std::endl;
195-
// amrex::Print() << "padded ba\n" << tmp_padded_ba << std::endl;
196-
// amrex::Print() << "padded dm\n" << tmp_padded_dm << std::endl;
197-
// amrex::Print() << "padded geom\n" << padded_geom << std::endl;
198-
199172
amrex::MultiFab tmp_field(original_ba, field.DistributionMap(), 1, 0);
200173
tmp_field.ParallelCopy(field, comp, 0, 1, 0, 0);
201174
amrex::MultiFab tmp_padded_field(tmp_padded_ba, tmp_padded_dm, 1, 0,
202175
amrex::MFInfo().SetAlloc(false));
203176

204-
// amrex::Print() << "nans: " << field.contains_nan() << " "
205-
// << tmp_field.contains_nan() << " " << std::flush
206-
// << std::endl;
207-
// amrex::Print() << "**********************************************"
208-
// << std::endl;
209-
210177
const int offset = original_ba.size();
211178
for (amrex::MFIter mfi(tmp_padded_field); mfi.isValid(); ++mfi) {
212179
if (mfi.index() < offset) {
@@ -244,158 +211,50 @@ static void Fft(const amrex::MultiFab &field, const int comp,
244211

245212
amrex::FillPatchSingleLevel(padded_field, 0, smf, stime, 0, 0, 1,
246213
padded_geom, physbc, 0);
214+
// Use the new amrex::FFT setup
215+
amrex::Box domain = padded_ba.minimalBox();
216+
amrex::FFT::R2C my_fft(domain);
247217

248-
if (false) {
249-
std::string subfolder =
250-
"debug_output/fft_" + std::to_string(++fft_counter);
251-
amrex::UtilCreateDirectory(subfolder.c_str(), 0755);
252-
253-
std::string filename =
254-
subfolder + "/" +
255-
std::to_string(amrex::ParallelDescriptor::MyProc()) + ".hdf5";
256-
hid_t file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
257-
H5P_DEFAULT);
258-
259-
amrex::Print() << "Writing " << subfolder << std::endl;
260-
WriteThis(&field, file_id, comp);
218+
// create storage for the FFT
219+
auto const &[cba, cdm] = my_fft.getSpectralDataLayout();
261220

262-
H5Fclose(file_id);
263-
}
264-
if (false) {
265-
std::string subfolder =
266-
"debug_output/fft_" + std::to_string(++fft_counter);
267-
amrex::UtilCreateDirectory(subfolder.c_str(), 0755);
268-
269-
std::string filename =
270-
subfolder + "/" +
271-
std::to_string(amrex::ParallelDescriptor::MyProc()) + ".hdf5";
272-
hid_t file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
273-
H5P_DEFAULT);
274-
275-
amrex::Print() << "Writing " << subfolder << std::endl;
276-
WriteThis(&tmp_padded_field, file_id, 0);
221+
// amrex::Print() << "padded ba" << padded_ba << std::endl;
222+
// amrex::Print() << "cba" << cba << std::endl;
277223

278-
H5Fclose(file_id);
279-
}
280-
if (false) {
281-
std::string subfolder =
282-
"debug_output/fft_" + std::to_string(++fft_counter);
283-
amrex::UtilCreateDirectory(subfolder.c_str(), 0755);
284-
285-
std::string filename =
286-
subfolder + "/" +
287-
std::to_string(amrex::ParallelDescriptor::MyProc()) + ".hdf5";
288-
hid_t file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
289-
H5P_DEFAULT);
224+
field_fft_real_or_abs.define(cba, cdm, 1, 0);
225+
field_fft_imag.define(cba, cdm, 1, 0);
290226

291-
amrex::Print() << "Writing " << subfolder << std::endl;
292-
WriteThis(&padded_field, file_id, 0);
227+
amrex::FabArray<amrex::BaseFab<amrex::GpuComplex<amrex::Real>>> phi_fft(
228+
cba, cdm, 1, 0);
229+
my_fft.forward(padded_field, phi_fft);
293230

294-
H5Fclose(file_id);
295-
}
231+
for (amrex::MFIter mfi(phi_fft); mfi.isValid(); ++mfi) {
296232

297-
// amrex::Print() << "nans: " << field.contains_nan() << " "
298-
// << padded_field.contains_nan() << std::flush <<
299-
// std::endl;
300-
// amrex::Print() << "**********************************************"
301-
// << std::endl;
233+
amrex::Array4<amrex::GpuComplex<amrex::Real>> const &phi_fft_ptr =
234+
phi_fft.array(mfi);
235+
amrex::Array4<amrex::Real> real_or_abs =
236+
field_fft_real_or_abs.array(mfi);
237+
amrex::Array4<amrex::Real> imag = field_fft_imag.array(mfi);
302238

303-
/*
304-
// Get crude zero-padded boxArray.
305-
const amrex::BoxArray &original_ba = field.boxArray();
306-
amrex::Box minimal_box = original_ba.minimalBox();
307-
minimal_box.refine(zero_padding);
308-
// check this yields volume outside of ba domain.
309-
amrex::BoxList extra_bl = original_ba.complementIn(minimal_box);
310-
amrex::BoxArray extra_ba(std::move(extra_bl));
311-
ChopGrids(extra_ba, amrex::ParallelDescriptor::NProcs());
312-
amrex::BoxList new_bl = original_ba.boxList();
313-
new_bl.join(extra_ba.boxList());
314-
amrex::BoxArray new_ba(new_bl);
315-
316-
// Get new crude DistributionMap
317-
const amrex::DistributionMapping &original_dm = field.DistributionMap();
318-
amrex::DistributionMapping extra_dm(extra_ba,
319-
amrex::ParallelDescriptor::NProcs());
320-
amrex::Vector<int> extra_pmap = extra_dm.ProcessorMap();
321-
amrex::Vector<int> new_pmap = original_dm.ProcessorMap();
322-
std::move(extra_pmap.begin(), extra_pmap.end(),
323-
std::back_inserter(new_pmap));
324-
amrex::DistributionMapping new_dm(new_pmap);
325-
326-
// Fill new crude MultiFab
327-
amrex::MultiFab field_tmp(new_ba, new_dm, 1, 0);
328-
for (amrex::MFIter mfi(field, false); mfi.isValid(); ++mfi) {
329-
const amrex::Box &bx = mfi.tilebox();
330-
const auto &state_arr = field.array(mfi);
331-
const auto &field_tmp_arr = field_tmp.array(mfi);
239+
const amrex::Box &bx = mfi.fabbox();
332240

333-
amrex::ParallelFor(
334-
bx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
335-
field_tmp_arr(i, j, k, 0) = state_arr(i, j, k, comp);
241+
if (abs) {
242+
amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j,
243+
int k) noexcept {
244+
real_or_abs(i, j, k, 0) = std::sqrt(
245+
phi_fft_ptr(i, j, k).real() * phi_fft_ptr(i, j, k).real() +
246+
phi_fft_ptr(i, j, k).imag() * phi_fft_ptr(i, j, k).imag());
336247
});
248+
} else {
249+
amrex::ParallelFor(
250+
bx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
251+
real_or_abs(i, j, k, 0) = phi_fft_ptr(i, j, k).real();
252+
imag(i, j, k, 0) = phi_fft_ptr(i, j, k).imag();
253+
});
254+
}
337255
}
338256

339-
// Now refine layout so we can feed it into SWFFT
340-
amrex::BoxArray padded_ba(minimal_box);
341-
ChopGrids(padded_ba, amrex::ParallelDescriptor::NProcs());
342-
amrex::DistributionMapping padded_dm(padded_ba,
343-
amrex::ParallelDescriptor::NProcs());
344-
amrex::Geometry padded_geom(geom);
345-
padded_geom.refine(
346-
amrex::IntVect(zero_padding, zero_padding, zero_padding));
347-
348-
amrex::MultiFab padded_field(padded_ba, padded_dm, 1, 0);
349-
350-
amrex::Print() << "********************* bef" <<
351-
field_tmp.contains_nan()
352-
<< " " << padded_field.contains_nan() << std::endl;
353-
padded_field.ParallelCopy(field_tmp);
354-
amrex::Print() << "********************* aft" <<
355-
field_tmp.contains_nan()
356-
<< " " << padded_field.contains_nan() << std::endl;
357-
*/
358-
/*
359-
amrex::Vector<amrex::BCRec> bcs;
360-
bcs.resize(1);
361-
for (int i = 0; i < AMREX_SPACEDIM; ++i) {
362-
bcs[0].setLo(i, amrex::BCType::int_dir);
363-
bcs[0].setHi(i, amrex::BCType::int_dir);
364-
}
365-
366-
amrex::CpuBndryFuncFab bndry_func(nullptr);
367-
amrex::PhysBCFunct<amrex::CpuBndryFuncFab> physbc(padded_geom, bcs,
368-
bndry_func);
369-
amrex::Vector<amrex::MultiFab *> smf{
370-
static_cast<amrex::MultiFab *>(&field_tmp)};
371-
amrex::Vector<double> stime{0};
372-
373-
amrex::FillPatchSingleLevel(padded_field, 0, smf, stime, 0, 0, 1,
374-
padded_geom, physbc, 0);
375-
*/
376-
// tmp_field.clear();
377-
field_fft_real_or_abs.define(padded_ba, padded_dm, 1, 0);
378-
field_fft_imag.define(padded_ba, padded_dm, 1, 0);
379-
380-
// Debug output
381257
/*
382-
amrex::Print() << "original_ba:\n"
383-
<< original_ba << std::flush << std::endl;
384-
amrex::Print() << "minimal_box:\n" << minimal_box << std::endl;
385-
amrex::Print() << "extra ba:\n" << extra_ba << std::flush << std::endl;
386-
amrex::Print() << "new ba:\n" << new_ba << std::flush << std::endl;
387-
amrex::Print() << "original dm\n" << original_dm << std::flush << std::endl;
388-
amrex::Print() << "extra dm\n" << extra_dm << std::flush << std::endl;
389-
amrex::Print() << "new dm\n" << new_dm << std::flush << std::endl;
390-
amrex::Print() << "padded ba\n" << padded_ba << std::flush << std::endl;
391-
amrex::Print() << "padded dm\n" << padded_dm << std::flush << std::endl;
392-
amrex::Print() << "contains nan tmp field\n"
393-
<< field_tmp.contains_nan() << std::flush << std::endl;
394-
amrex::Print() << "contains nan padded field\n"
395-
<< padded_field.contains_nan() << std::flush << std::endl;
396-
amrex::Print() << "padded geom domains\n"
397-
<< padded_geom.Domain() << std::flush << std::endl;
398-
*/
399258
// Now setup SWFFT
400259
int nx = padded_ba[0].size()[0];
401260
int ny = padded_ba[0].size()[1];
@@ -481,6 +340,7 @@ static void Fft(const amrex::MultiFab &field, const int comp,
481340
}
482341
}
483342
}
343+
*/
484344
}
485345

486346
}; // namespace utils

0 commit comments

Comments
 (0)