Skip to content

Commit edac832

Browse files
committed
MultiFab: to_numpy/cupy
Add numpy & cupy helpers for MultiFab.
1 parent 596f0e7 commit edac832

File tree

5 files changed

+36
-17
lines changed

5 files changed

+36
-17
lines changed

src/amrex/MultiFab.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
License: BSD-3-Clause-LBNL
77
"""
88

9-
10-
def mf_to_numpy(self, copy=False, order="F"):
9+
def mf_to_numpy(amr, self, copy=False, order="F"):
1110
"""
1211
Provide a Numpy view into a MultiFab.
1312
@@ -29,13 +28,24 @@ def mf_to_numpy(self, copy=False, order="F"):
2928
3029
Returns
3130
-------
32-
list of np.array
31+
list of numpy.array
3332
A list of numpy n-dimensional arrays, for each local block in the
3433
MultiFab.
3534
"""
35+
mf = self
36+
if copy:
37+
mf = amr.MultiFab(
38+
self.box_array(),
39+
self.dm(),
40+
self.n_comp(),
41+
self.n_grow_vect(),
42+
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
43+
)
44+
amr.dtoh_memcpy(mf, self)
45+
3646
views = []
37-
for mfi in self:
38-
views.append(self.array(mfi).to_numpy(copy, order))
47+
for mfi in mf:
48+
views.append(mf.array(mfi).to_numpy(copy=False, order=order))
3949

4050
return views
4151

@@ -80,15 +90,9 @@ def mf_to_cupy(self, copy=False, order="F"):
8090

8191
def register_MultiFab_extension(amr):
8292
"""MultiFab helper methods"""
83-
import inspect
84-
import sys
85-
86-
# register member functions for every MultiFab* type
87-
for _, MultiFab_type in inspect.getmembers(
88-
sys.modules[amr.__name__],
89-
lambda member: inspect.isclass(member)
90-
and member.__module__ == amr.__name__
91-
and member.__name__.startswith("MultiFab"),
92-
):
93-
MultiFab_type.to_numpy = mf_to_numpy
94-
MultiFab_type.to_cupy = mf_to_cupy
93+
94+
# register member functions for the MultiFab type
95+
amr.MultiFab.to_numpy = lambda self, copy=False, order="F": mf_to_numpy(amr, self, copy, order)
96+
amr.MultiFab.to_numpy.__doc__ = mf_to_numpy.__doc__
97+
98+
amr.MultiFab.to_cupy = mf_to_cupy

src/amrex/space1d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ def Print(*args, **kwargs):
4545

4646

4747
from ..Array4 import register_Array4_extension
48+
from ..MultiFab import register_MultiFab_extension
4849
from ..ArrayOfStructs import register_AoS_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_1d_pybind)
54+
register_MultiFab_extension(amrex_1d_pybind)
5355
register_PODVector_extension(amrex_1d_pybind)
5456
register_SoA_extension(amrex_1d_pybind)
5557
register_AoS_extension(amrex_1d_pybind)

src/amrex/space2d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ def Print(*args, **kwargs):
4545

4646

4747
from ..Array4 import register_Array4_extension
48+
from ..MultiFab import register_MultiFab_extension
4849
from ..ArrayOfStructs import register_AoS_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_2d_pybind)
54+
register_MultiFab_extension(amrex_2d_pybind)
5355
register_PODVector_extension(amrex_2d_pybind)
5456
register_SoA_extension(amrex_2d_pybind)
5557
register_AoS_extension(amrex_2d_pybind)

src/amrex/space3d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ def Print(*args, **kwargs):
4545

4646

4747
from ..Array4 import register_Array4_extension
48+
from ..MultiFab import register_MultiFab_extension
4849
from ..ArrayOfStructs import register_AoS_extension
4950
from ..PODVector import register_PODVector_extension
5051
from ..StructOfArrays import register_SoA_extension
5152

5253
register_Array4_extension(amrex_3d_pybind)
54+
register_MultiFab_extension(amrex_3d_pybind)
5355
register_PODVector_extension(amrex_3d_pybind)
5456
register_SoA_extension(amrex_3d_pybind)
5557
register_AoS_extension(amrex_3d_pybind)

tests/test_multifab.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,12 @@ def test_mfab_dtoh_copy(make_mfab_device):
350350
device_max = mfab_device.max(0)
351351
assert device_min == device_max
352352
assert device_max == 11.0
353+
354+
# numpy bindings (w/ copy)
355+
local_boxes_host = mfab_device.to_numpy(copy=True)
356+
assert max([np.max(box) for box in local_boxes_host]) == device_max
357+
358+
# cupy bindings (w/o copy)
359+
import cupy as cp
360+
local_boxes_device = mfab_device.to_cupy()
361+
assert max([cp.max(box) for box in local_boxes_device]) == device_max

0 commit comments

Comments
 (0)