Skip to content

Commit 3c73a42

Browse files
authored
MultiFab Fixture Cleanup via FabArray::clear (#214)
* MultiFab Fixture Cleanup via `FabArray::clear` Using a context manager and calling clear ensures that we will not hold device memory anymore once we hit `AMReX::Finalize`, even in the situation where an exception is raised in a test. This avoids segfaults for failing tests. * `test_mfab_dtoh_copy`: Clear MFabs Clear out memory safely on runtime errors. * Update Stub Files --------- Co-authored-by: ax3l <[email protected]>
1 parent 056d332 commit 3c73a42

File tree

6 files changed

+103
-79
lines changed

6 files changed

+103
-79
lines changed

src/Base/MultiFab.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ void init_MultiFab(py::module &m)
127127
;
128128

129129
py_FabArray_FArrayBox
130+
// define
131+
.def("clear", &FabArray<FArrayBox>::clear)
132+
.def("ok", &FabArray<FArrayBox>::ok)
133+
130134
//.def("array", py::overload_cast< const MFIter& >(&FabArray<FArrayBox>::array))
131135
//.def("const_array", &FabArray<FArrayBox>::const_array)
132136
.def("array", [](FabArray<FArrayBox> & fa, MFIter const & mfi)

src/amrex/space1d/amrex_1d_pybind/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3755,6 +3755,7 @@ class FabArray_FArrayBox(FabArrayBase):
37553755
arg6: IntVect,
37563756
) -> None: ...
37573757
def array(self, arg0: MFIter) -> Array4_double: ...
3758+
def clear(self) -> None: ...
37583759
def const_array(self, arg0: MFIter) -> Array4_double_const: ...
37593760
@typing.overload
37603761
def fill_boundary(self, cross: bool = False) -> None: ...
@@ -3779,6 +3780,7 @@ class FabArray_FArrayBox(FabArrayBase):
37793780
period: Periodicity,
37803781
cross: bool = False,
37813782
) -> None: ...
3783+
def ok(self) -> bool: ...
37823784
def override_sync(self, arg0: Periodicity) -> None: ...
37833785
def sum(self, arg0: int, arg1: IntVect, arg2: bool) -> float: ...
37843786
@typing.overload

src/amrex/space2d/amrex_2d_pybind/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3755,6 +3755,7 @@ class FabArray_FArrayBox(FabArrayBase):
37553755
arg6: IntVect,
37563756
) -> None: ...
37573757
def array(self, arg0: MFIter) -> Array4_double: ...
3758+
def clear(self) -> None: ...
37583759
def const_array(self, arg0: MFIter) -> Array4_double_const: ...
37593760
@typing.overload
37603761
def fill_boundary(self, cross: bool = False) -> None: ...
@@ -3779,6 +3780,7 @@ class FabArray_FArrayBox(FabArrayBase):
37793780
period: Periodicity,
37803781
cross: bool = False,
37813782
) -> None: ...
3783+
def ok(self) -> bool: ...
37823784
def override_sync(self, arg0: Periodicity) -> None: ...
37833785
def sum(self, arg0: int, arg1: IntVect, arg2: bool) -> float: ...
37843786
@typing.overload

src/amrex/space3d/amrex_3d_pybind/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3755,6 +3755,7 @@ class FabArray_FArrayBox(FabArrayBase):
37553755
arg6: IntVect,
37563756
) -> None: ...
37573757
def array(self, arg0: MFIter) -> Array4_double: ...
3758+
def clear(self) -> None: ...
37583759
def const_array(self, arg0: MFIter) -> Array4_double_const: ...
37593760
@typing.overload
37603761
def fill_boundary(self, cross: bool = False) -> None: ...
@@ -3779,6 +3780,7 @@ class FabArray_FArrayBox(FabArrayBase):
37793780
period: Periodicity,
37803781
cross: bool = False,
37813782
) -> None: ...
3783+
def ok(self) -> bool: ...
37823784
def override_sync(self, arg0: Periodicity) -> None: ...
37833785
def sum(self, arg0: int, arg1: IntVect, arg2: bool) -> float: ...
37843786
@typing.overload

tests/conftest.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -85,47 +85,59 @@ def distmap(boxarr):
8585

8686

8787
@pytest.fixture(scope="function", params=list(itertools.product([1, 3], [0, 1])))
88-
def make_mfab(boxarr, distmap, request):
88+
def mfab(boxarr, distmap, request):
8989
"""MultiFab that is either managed or device:
9090
The MultiFab object itself is not a fixture because we want to avoid caching
9191
it between amr.initialize/finalize calls of various tests.
9292
https://github.com/pytest-dev/pytest/discussions/10387
9393
https://github.com/pytest-dev/pytest/issues/5642#issuecomment-1279612764
9494
"""
9595

96-
def create():
97-
num_components = request.param[0]
98-
num_ghost = request.param[1]
99-
mfab = amr.MultiFab(boxarr, distmap, num_components, num_ghost)
100-
mfab.set_val(0.0, 0, num_components)
101-
return mfab
96+
class MfabContextManager:
97+
def __enter__(self):
98+
num_components = request.param[0]
99+
num_ghost = request.param[1]
100+
self.mfab = amr.MultiFab(boxarr, distmap, num_components, num_ghost)
101+
self.mfab.set_val(0.0, 0, num_components)
102+
return self.mfab
102103

103-
return create
104+
def __exit__(self, exc_type, exc_value, traceback):
105+
self.mfab.clear()
106+
del self.mfab
107+
108+
with MfabContextManager() as mfab:
109+
yield mfab
104110

105111

106112
@pytest.mark.skipif(
107113
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
108114
)
109115
@pytest.fixture(scope="function", params=list(itertools.product([1, 3], [0, 1])))
110-
def make_mfab_device(boxarr, distmap, request):
116+
def mfab_device(boxarr, distmap, request):
111117
"""MultiFab that resides purely on the device:
112118
The MultiFab object itself is not a fixture because we want to avoid caching
113119
it between amr.initialize/finalize calls of various tests.
114120
https://github.com/pytest-dev/pytest/discussions/10387
115121
https://github.com/pytest-dev/pytest/issues/5642#issuecomment-1279612764
116122
"""
117123

118-
def create():
119-
num_components = request.param[0]
120-
num_ghost = request.param[1]
121-
mfab = amr.MultiFab(
122-
boxarr,
123-
distmap,
124-
num_components,
125-
num_ghost,
126-
amr.MFInfo().set_arena(amr.The_Device_Arena()),
127-
)
128-
mfab.set_val(0.0, 0, num_components)
129-
return mfab
130-
131-
return create
124+
class MfabDeviceContextManager:
125+
def __enter__(self):
126+
num_components = request.param[0]
127+
num_ghost = request.param[1]
128+
self.mfab = amr.MultiFab(
129+
boxarr,
130+
distmap,
131+
num_components,
132+
num_ghost,
133+
amr.MFInfo().set_arena(amr.The_Device_Arena()),
134+
)
135+
self.mfab.set_val(0.0, 0, num_components)
136+
return self.mfab
137+
138+
def __exit__(self, exc_type, exc_value, traceback):
139+
self.mfab.clear()
140+
del self.mfab
141+
142+
with MfabDeviceContextManager() as mfab:
143+
yield mfab

tests/test_multifab.py

Lines changed: 58 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import amrex.space3d as amr
99

1010

11-
def test_mfab_loop(make_mfab):
12-
mfab = make_mfab()
11+
def test_mfab_loop(mfab):
1312
ngv = mfab.nGrowVect
1413
print(f"\n mfab={mfab}, mfab.nGrowVect={ngv}")
1514

@@ -78,8 +77,7 @@ def test_mfab_loop(make_mfab):
7877
# TODO
7978

8079

81-
def test_mfab_simple(make_mfab):
82-
mfab = make_mfab()
80+
def test_mfab_simple(mfab):
8381
assert mfab.is_all_cell_centered
8482
# assert(all(not mfab.is_nodal(i) for i in [-1, 0, 1, 2])) # -1??
8583
assert all(not mfab.is_nodal(i) for i in [0, 1, 2])
@@ -144,8 +142,7 @@ def test_mfab_ops(boxarr, distmap, nghost):
144142
np.testing.assert_allclose(dst.max(0), 150.0)
145143

146144

147-
def test_mfab_mfiter(make_mfab):
148-
mfab = make_mfab()
145+
def test_mfab_mfiter(mfab):
149146
assert iter(mfab).is_valid
150147
assert iter(mfab).length == 8
151148

@@ -159,8 +156,7 @@ def test_mfab_mfiter(make_mfab):
159156
@pytest.mark.skipif(
160157
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
161158
)
162-
def test_mfab_ops_cuda_numba(make_mfab_device):
163-
mfab_device = make_mfab_device()
159+
def test_mfab_ops_cuda_numba(mfab_device):
164160
# https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
165161
from numba import cuda
166162

@@ -195,8 +191,7 @@ def set_to_three(array):
195191
@pytest.mark.skipif(
196192
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
197193
)
198-
def test_mfab_ops_cuda_cupy(make_mfab_device):
199-
mfab_device = make_mfab_device()
194+
def test_mfab_ops_cuda_cupy(mfab_device):
200195
# https://docs.cupy.dev/en/stable/user_guide/interoperability.html
201196
import cupy as cp
202197
import cupyx.profiler
@@ -285,8 +280,7 @@ def set_to_seven(x):
285280
@pytest.mark.skipif(
286281
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
287282
)
288-
def test_mfab_ops_cuda_pytorch(make_mfab_device):
289-
mfab_device = make_mfab_device()
283+
def test_mfab_ops_cuda_pytorch(mfab_device):
290284
# https://docs.cupy.dev/en/stable/user_guide/interoperability.html#pytorch
291285
import torch
292286

@@ -305,8 +299,8 @@ def test_mfab_ops_cuda_pytorch(make_mfab_device):
305299
@pytest.mark.skipif(
306300
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
307301
)
308-
def test_mfab_ops_cuda_cuml(make_mfab_device):
309-
mfab_device = make_mfab_device() # noqa
302+
def test_mfab_ops_cuda_cuml(mfab_device):
303+
pass
310304
# https://github.com/rapidsai/cuml
311305
# https://github.com/rapidsai/cudf
312306
# maybe better for particles as a dataframe test
@@ -322,47 +316,55 @@ def test_mfab_ops_cuda_cuml(make_mfab_device):
322316
@pytest.mark.skipif(
323317
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
324318
)
325-
def test_mfab_dtoh_copy(make_mfab_device):
326-
mfab_device = make_mfab_device()
327-
328-
mfab_host = amr.MultiFab(
329-
mfab_device.box_array(),
330-
mfab_device.dm(),
331-
mfab_device.n_comp(),
332-
mfab_device.n_grow_vect(),
333-
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
334-
)
335-
mfab_host.set_val(42.0)
336-
337-
amr.dtoh_memcpy(mfab_host, mfab_device)
338-
339-
# assert all are 0.0 on host
340-
host_min = mfab_host.min(0)
341-
host_max = mfab_host.max(0)
342-
assert host_min == host_max
343-
assert host_max == 0.0
344-
345-
dev_val = 11.0
346-
mfab_host.set_val(dev_val)
347-
amr.dtoh_memcpy(mfab_device, mfab_host)
348-
349-
# assert all are 11.0 on device
350-
for n in range(mfab_device.n_comp()):
351-
assert mfab_device.min(comp=n) == dev_val
352-
assert mfab_device.max(comp=n) == dev_val
353-
354-
# numpy bindings (w/ copy)
355-
local_boxes_host = mfab_device.to_numpy(copy=True)
356-
assert max([np.max(box) for box in local_boxes_host]) == dev_val
357-
358-
# numpy bindings (w/ copy)
359-
for mfi in mfab_device:
360-
marr = mfab_device.array(mfi).to_numpy(copy=True)
361-
assert np.min(marr) >= dev_val
362-
assert np.max(marr) <= dev_val
319+
def test_mfab_dtoh_copy(mfab_device):
320+
class MfabPinnedContextManager:
321+
def __enter__(self):
322+
self.mfab = amr.MultiFab(
323+
mfab_device.box_array(),
324+
mfab_device.dm(),
325+
mfab_device.n_comp(),
326+
mfab_device.n_grow_vect(),
327+
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
328+
)
329+
return self.mfab
330+
331+
def __exit__(self, exc_type, exc_value, traceback):
332+
self.mfab.clear()
333+
del self.mfab
334+
335+
with MfabPinnedContextManager() as mfab_host:
336+
mfab_host.set_val(42.0)
337+
338+
amr.dtoh_memcpy(mfab_host, mfab_device)
339+
340+
# assert all are 0.0 on host
341+
host_min = mfab_host.min(0)
342+
host_max = mfab_host.max(0)
343+
assert host_min == host_max
344+
assert host_max == 0.0
345+
346+
dev_val = 11.0
347+
mfab_host.set_val(dev_val)
348+
amr.htod_memcpy(mfab_device, mfab_host)
349+
350+
# assert all are 11.0 on device
351+
for n in range(mfab_device.n_comp()):
352+
assert mfab_device.min(comp=n) == dev_val
353+
assert mfab_device.max(comp=n) == dev_val
354+
355+
# numpy bindings (w/ copy)
356+
local_boxes_host = mfab_device.to_numpy(copy=True)
357+
assert max([np.max(box) for box in local_boxes_host]) == dev_val
358+
del local_boxes_host
359+
360+
# numpy bindings (w/ copy)
361+
for mfi in mfab_device:
362+
marr = mfab_device.array(mfi).to_numpy(copy=True)
363+
assert np.min(marr) >= dev_val
364+
assert np.max(marr) <= dev_val
363365

364-
# cupy bindings (w/o copy)
365-
import cupy as cp
366+
# cupy bindings (w/o copy)
367+
import cupy as cp
366368

367-
local_boxes_device = mfab_device.to_cupy()
368-
assert max([cp.max(box) for box in local_boxes_device]) == dev_val
369+
local_boxes_device = mfab_device.to_cupy()
370+
assert max([cp.max(box) for box in local_boxes_device]) == dev_val

0 commit comments

Comments
 (0)