Skip to content

Commit 4064d5e

Browse files
committed
Pandas: ParticleContainer_*.to_df()
Copy all particles into a `pandas.DataFrame`. Supports local and MPI-gathered results.
1 parent a47db85 commit 4064d5e

File tree

6 files changed

+146
-0
lines changed

6 files changed

+146
-0
lines changed

src/Particle/ParticleContainer.H

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ void make_Base_Iterators (py::module &m, std::string allocstr)
6868
py::return_value_policy::reference_internal)
6969

7070
.def_property_readonly_static("is_soa_particle", [](const py::object&){ return ParticleType::is_soa_particle;})
71+
.def_property_readonly("size", &iterator_base::numParticles,
72+
"the number of particles on this tile")
7173
.def_property_readonly("num_particles", &iterator_base::numParticles)
7274
.def_property_readonly("num_real_particles", &iterator_base::numRealParticles)
7375
.def_property_readonly("num_neighbor_particles", &iterator_base::numNeighborParticles)
@@ -382,6 +384,14 @@ void make_ParticleContainer_and_Iterators (py::module &m, std::string allocstr)
382384
make_Iterators< false, iterator, Allocator >(m, allocstr);
383385
using const_iterator = amrex::ParConstIter_impl<ParticleType, T_NArrayReal, T_NArrayInt, Allocator>;
384386
make_Iterators< true, const_iterator, Allocator >(m, allocstr);
387+
388+
// simpler particle iterator loops: return types of this particle box
389+
py_pc
390+
.def_property_readonly_static("iterator", [](py::object /* pc */){ return py::type::of<iterator>(); },
391+
"amrex iterator for particle boxes")
392+
.def_property_readonly_static("const_iterator", [](py::object /* pc */){ return py::type::of<const_iterator>(); },
393+
"amrex constant iterator for particle boxes (read-only)")
394+
;
385395
}
386396

387397
/** Create ParticleContainers and Iterators

src/amrex/ParticleContainer.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""
2+
This file is part of pyAMReX
3+
4+
Copyright 2023 AMReX community
5+
Authors: Axel Huebl
6+
License: BSD-3-Clause-LBNL
7+
"""
8+
9+
10+
def pc_to_df(self, local=True, comm=None, root_rank=0):
11+
"""
12+
Copy all particles into a pandas.DataFrame
13+
14+
Parameters
15+
----------
16+
self : amrex.ParticleContainer_*
17+
A ParticleContainer class in pyAMReX
18+
local : bool
19+
MPI-local particles
20+
comm : MPI Communicator
21+
if local is False, this defaults to mpi4py.MPI.COMM_WORLD
22+
root_rank : MPI root rank to gather to
23+
if local is False, this defaults to 0
24+
25+
Returns
26+
-------
27+
A concatenated pandas.DataFrame with particles from all levels.
28+
29+
Returns None if no particles were found.
30+
If local=False, then all ranks but the root_rank will return None.
31+
"""
32+
import pandas as pd
33+
34+
# create a DataFrame per particle box and append it to the list of
35+
# local DataFrame(s)
36+
dfs_local = []
37+
for lvl in range(self.finest_level + 1):
38+
for pti in self.const_iterator(self, level=lvl):
39+
if pti.size == 0:
40+
continue
41+
42+
if self.is_soa_particle:
43+
next_df = pd.DataFrame()
44+
else:
45+
# AoS
46+
aos_np = pti.aos().to_numpy(copy=True)
47+
next_df = pd.DataFrame(aos_np)
48+
next_df.set_index("cpuid")
49+
next_df.index.name = "cpuid"
50+
51+
# SoA
52+
soa_view = pti.soa().to_numpy(copy=True)
53+
soa_np_real = soa_view.real
54+
soa_np_int = soa_view.int
55+
56+
for idx, array in enumerate(soa_np_real):
57+
next_df[f"SoA_real_{idx}"] = array
58+
for idx, array in enumerate(soa_np_int):
59+
next_df[f"SoA_int_{idx}"] = array
60+
61+
dfs_local.append(next_df)
62+
63+
# MPI Gather to root rank if requested
64+
if local:
65+
if len(dfs_local) == 0:
66+
df = None
67+
else:
68+
df = pd.concat(dfs_local)
69+
else:
70+
from mpi4py import MPI
71+
72+
if comm is None:
73+
comm = MPI.COMM_WORLD
74+
rank = comm.Get_rank()
75+
76+
# a list for each rank's list of DataFrame(s)
77+
df_list_list = comm.gather(dfs_local, root=root_rank)
78+
79+
if rank == root_rank:
80+
flattened_list = [df for sublist in df_list_list for df in sublist]
81+
82+
if len(flattened_list) == 0:
83+
df = pd.DataFrame()
84+
else:
85+
df = pd.concat(flattened_list, ignore_index=True)
86+
else:
87+
df = None
88+
89+
return df
90+
91+
92+
def register_ParticleContainer_extension(amr):
93+
"""ParticleContainer helper methods"""
94+
import inspect
95+
import sys
96+
97+
# register member functions for every ParticleContainer_* type
98+
for _, ParticleContainer_type in inspect.getmembers(
99+
sys.modules[amr.__name__],
100+
lambda member: inspect.isclass(member)
101+
and member.__module__ == amr.__name__
102+
and member.__name__.startswith("ParticleContainer_"),
103+
):
104+
ParticleContainer_type.to_df = pc_to_df

src/amrex/space1d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ def Print(*args, **kwargs):
4848
from ..ArrayOfStructs import register_AoS_extension
4949
from ..MultiFab import register_MultiFab_extension
5050
from ..PODVector import register_PODVector_extension
51+
from ..ParticleContainer import register_ParticleContainer_extension
5152
from ..StructOfArrays import register_SoA_extension
5253

5354
register_Array4_extension(amrex_1d_pybind)
5455
register_MultiFab_extension(amrex_1d_pybind)
5556
register_PODVector_extension(amrex_1d_pybind)
5657
register_SoA_extension(amrex_1d_pybind)
5758
register_AoS_extension(amrex_1d_pybind)
59+
register_ParticleContainer_extension(amrex_1d_pybind)

src/amrex/space2d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ def Print(*args, **kwargs):
4848
from ..ArrayOfStructs import register_AoS_extension
4949
from ..MultiFab import register_MultiFab_extension
5050
from ..PODVector import register_PODVector_extension
51+
from ..ParticleContainer import register_ParticleContainer_extension
5152
from ..StructOfArrays import register_SoA_extension
5253

5354
register_Array4_extension(amrex_2d_pybind)
5455
register_MultiFab_extension(amrex_2d_pybind)
5556
register_PODVector_extension(amrex_2d_pybind)
5657
register_SoA_extension(amrex_2d_pybind)
5758
register_AoS_extension(amrex_2d_pybind)
59+
register_ParticleContainer_extension(amrex_2d_pybind)

src/amrex/space3d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ def Print(*args, **kwargs):
4848
from ..ArrayOfStructs import register_AoS_extension
4949
from ..MultiFab import register_MultiFab_extension
5050
from ..PODVector import register_PODVector_extension
51+
from ..ParticleContainer import register_ParticleContainer_extension
5152
from ..StructOfArrays import register_SoA_extension
5253

5354
register_Array4_extension(amrex_3d_pybind)
5455
register_MultiFab_extension(amrex_3d_pybind)
5556
register_PODVector_extension(amrex_3d_pybind)
5657
register_SoA_extension(amrex_3d_pybind)
5758
register_AoS_extension(amrex_3d_pybind)
59+
register_ParticleContainer_extension(amrex_3d_pybind)

tests/test_particleContainer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,29 @@ def test_per_cell(empty_particle_container, std_geometry, std_particle):
272272
assert pc.TotalNumberOfParticles() == pc.NumberOfParticlesAtLevel(0) == ncells
273273
print("npts * real_1", ncells * std_particle.real_array_data[1])
274274
assert ncells * std_particle.real_array_data[1] == sum_1
275+
276+
277+
def test_pc_df(particle_container, Npart):
278+
pc = particle_container
279+
print(f"pc={pc}")
280+
df = pc.to_df()
281+
print(df.columns)
282+
print(df)
283+
284+
285+
def test_pc_empty_df(empty_particle_container, Npart):
286+
pc = empty_particle_container
287+
print(f"pc={pc}")
288+
df = pc.to_df()
289+
assert df is None
290+
291+
292+
@pytest.mark.skipif(not amr.Config.have_mpi, reason="Requires AMReX_MPI=ON")
293+
def test_pc_df_mpi(particle_container, Npart):
294+
pc = particle_container
295+
print(f"pc={pc}")
296+
df = pc.to_df(local=False)
297+
if df is not None:
298+
# only rank 0
299+
print(df.columns)
300+
print(df)

0 commit comments

Comments
 (0)