Skip to content

Commit b13f68c

Browse files
BenWibkingpre-commit-ci[bot]ax3l
authored
add PlotFileData bindings (#320)
* add PlotFileData bindings * Update PlotFileUtil.cpp * Update PlotFileUtil.cpp * fix bugs * fix bindings for overloaded function * add unit test for simple single-level 3D plotfile * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix box, string asserts * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint error * add script to plot 2d plotfiles * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Axel Huebl <[email protected]> * Tests: Import pytest * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Import amr * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * generate test data on-the-fly * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove extra import * remove extra import (again) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Axel Huebl <[email protected]>
1 parent 22124ea commit b13f68c

File tree

3 files changed

+164
-17
lines changed

3 files changed

+164
-17
lines changed

src/Base/PlotFileUtil.cpp

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,47 @@
55
#include "pyAMReX.H"
66

77
#include <AMReX_PlotFileUtil.H>
8-
#include <AMReX_Vector.H>
98
#include <AMReX_Print.H>
9+
#include <AMReX_Vector.H>
1010

1111
#include <sstream>
1212
#include <string>
1313

1414
namespace py = pybind11;
1515
using namespace amrex;
1616

17-
void init_PlotFileUtil(py::module& m)
18-
{
19-
m.def("write_single_level_plotfile",
20-
&amrex::WriteSingleLevelPlotfile,
21-
"Writes single level plotfile",
22-
py::arg("plotfilename"),
23-
py::arg("mf"),
24-
py::arg("varnames"),
25-
py::arg("geom"),
26-
py::arg("time"),
27-
py::arg("level_step"),
28-
py::arg("versionName")="HyperCLaw-V1.1",
29-
py::arg("levelPrefix")="Level_",
30-
py::arg("mfPrefix")="Cell",
31-
py::arg_v("extra_dirs", Vector<std::string>(), "list[str]")
32-
);
17+
void init_PlotFileUtil(py::module &m) {
18+
m.def("write_single_level_plotfile", &amrex::WriteSingleLevelPlotfile,
19+
"Writes single level plotfile", py::arg("plotfilename"), py::arg("mf"),
20+
py::arg("varnames"), py::arg("geom"), py::arg("time"),
21+
py::arg("level_step"), py::arg("versionName") = "HyperCLaw-V1.1",
22+
py::arg("levelPrefix") = "Level_", py::arg("mfPrefix") = "Cell",
23+
py::arg_v("extra_dirs", Vector<std::string>(), "list[str]"));
24+
25+
py::class_<PlotFileData>(m, "PlotFileData")
26+
// explicitly provide constructor argument types
27+
.def(py::init<std::string const&>())
28+
29+
.def("spaceDim", &PlotFileData::spaceDim)
30+
.def("time", &PlotFileData::time)
31+
.def("finestLevel", &PlotFileData::finestLevel)
32+
.def("refRatio", &PlotFileData::refRatio)
33+
.def("levelStep", &PlotFileData::levelStep)
34+
.def("boxArray", &PlotFileData::boxArray)
35+
.def("DistributionMap", &PlotFileData::DistributionMap)
36+
.def("syncDistributionMap", py::overload_cast<PlotFileData const&>(&PlotFileData::syncDistributionMap))
37+
.def("syncDistributionMap", py::overload_cast<int, PlotFileData const&>(&PlotFileData::syncDistributionMap))
38+
39+
.def("coordSys", &PlotFileData::coordSys)
40+
.def("probDomain", &PlotFileData::probDomain)
41+
.def("probSize", &PlotFileData::probSize)
42+
.def("probLo", &PlotFileData::probLo)
43+
.def("probHi", &PlotFileData::probHi)
44+
.def("cellSize", &PlotFileData::cellSize)
45+
.def("varNames", &PlotFileData::varNames)
46+
.def("nComp", &PlotFileData::nComp)
47+
.def("nGrowVect", &PlotFileData::nGrowVect)
48+
49+
.def("get", py::overload_cast<int>(&PlotFileData::get))
50+
.def("get", py::overload_cast<int, std::string const&>(&PlotFileData::get));
3351
}

tests/test_plotfiledata.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import numpy as np
4+
import pytest
5+
6+
import amrex.space3d as amr
7+
8+
9+
def write_test_plotfile(filename):
10+
"""Write single-level plotfile (in order to read it back in)."""
11+
domain_box = amr.Box([0, 0, 0], [31, 31, 31])
12+
real_box = amr.RealBox([-0.5, -0.5, -0.5], [0.5, 0.5, 0.5])
13+
geom = amr.Geometry(domain_box, real_box, amr.CoordSys.cartesian, [0, 0, 0])
14+
15+
ba = amr.BoxArray(domain_box)
16+
dm = amr.DistributionMapping(ba, 1)
17+
mf = amr.MultiFab(ba, dm, 1, 0)
18+
mf.set_val(np.pi)
19+
20+
time = 1.0
21+
level_step = 200
22+
var_names = amr.Vector_string(["density"])
23+
amr.write_single_level_plotfile(filename, mf, var_names, geom, time, level_step)
24+
25+
26+
@pytest.mark.skipif(amr.Config.spacedim != 3, reason="Requires AMREX_SPACEDIM = 3")
27+
def test_plotfiledata_read():
28+
"""Generate and then read plotfile using PlotFileUtil bindings."""
29+
plt_filename = "test_plt00200"
30+
write_test_plotfile(plt_filename)
31+
plt = amr.PlotFileData(plt_filename)
32+
33+
assert plt.spaceDim() == 3
34+
assert plt.time() == 1.0
35+
assert plt.finestLevel() == 0
36+
assert plt.refRatio(0) == 0
37+
assert plt.coordSys() == amr.CoordSys.cartesian
38+
39+
probDomain = plt.probDomain(0)
40+
probSize = plt.probSize()
41+
probLo = plt.probLo()
42+
probHi = plt.probHi()
43+
cellSize = plt.cellSize(0)
44+
varNames = plt.varNames()
45+
nComp = plt.nComp()
46+
nGrowVect = plt.nGrowVect(0)
47+
48+
assert probDomain.small_end == amr.IntVect(0, 0, 0)
49+
assert probDomain.big_end == amr.IntVect(31, 31, 31)
50+
51+
assert probSize == [1.0, 1.0, 1.0]
52+
assert probLo == [-0.5, -0.5, -0.5]
53+
assert probHi == [0.5, 0.5, 0.5]
54+
assert cellSize == [1.0 / 32.0, 1.0 / 32.0, 1.0 / 32.0]
55+
assert varNames == amr.Vector_string(["density"])
56+
assert nComp == 1
57+
assert nGrowVect == amr.IntVect(0, 0, 0)
58+
59+
for compname in varNames:
60+
mfab_comp = plt.get(0, compname)
61+
nboxes = 0
62+
63+
for mfi in mfab_comp:
64+
marr = mfab_comp.array(mfi)
65+
# numpy/cupy representation: non-copying view, including the
66+
# guard/ghost region
67+
marr_xp = marr.to_xp()
68+
assert marr_xp.shape == (32, 32, 32, 1)
69+
assert np.all(marr_xp[:, :, :, :] == np.pi)
70+
nboxes += 1
71+
72+
assert nboxes == 1

tools/plot_plotfile_2d.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import pytest
4+
5+
import amrex.space2d as amr
6+
7+
8+
def plot_mf(arr, compname, plo, phi):
9+
plt.plot()
10+
im = plt.imshow(
11+
arr.T,
12+
origin="lower",
13+
interpolation="none",
14+
extent=[plo[0], phi[0], plo[1], phi[1]],
15+
aspect="equal",
16+
)
17+
plt.colorbar(im)
18+
plt.tight_layout()
19+
plt.savefig(f"{compname}.png", dpi=150)
20+
plt.close()
21+
22+
23+
@pytest.mark.skipif(amr.Config.spacedim != 2, reason="Requires AMREX_SPACEDIM = 2")
24+
def plot_plotfile_2d(filename, level=0):
25+
plt = amr.PlotFileData(filename)
26+
assert level <= plt.finestLevel()
27+
28+
probDomain = plt.probDomain(level)
29+
probLo = plt.probLo()
30+
probHi = plt.probHi()
31+
varNames = plt.varNames()
32+
33+
for compname in varNames:
34+
mfab_comp = plt.get(level, compname)
35+
arr = np.zeros((probDomain.big_end - probDomain.small_end + 1))
36+
for mfi in mfab_comp:
37+
bx = mfi.tilebox()
38+
marr = mfab_comp.array(mfi)
39+
marr_xp = marr.to_xp()
40+
i_s, j_s = tuple(bx.small_end)
41+
i_e, j_e = tuple(bx.big_end)
42+
arr[i_s : i_e + 1, j_s : j_e + 1] = marr_xp[:, :, 0, 0]
43+
plot_mf(arr, compname, probLo, probHi)
44+
45+
46+
if __name__ == "__main__":
47+
import argparse
48+
49+
parser = argparse.ArgumentParser(
50+
"Plots each variable in a 2D plotfile using matplotlib."
51+
)
52+
parser.add_argument("filename", help="AMReX 2D plotfile to read")
53+
parser.add_argument("level", type=int, help="AMR level to plot (default: 0)")
54+
args = parser.parse_args()
55+
56+
amr.initialize([])
57+
plot_plotfile_2d(args.filename, level=args.level)

0 commit comments

Comments
 (0)