Skip to content

Commit 2593790

Browse files
committed
Test memory estimation with stack
1 parent b1b51a2 commit 2593790

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

tests/test_prep/test_stripe.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from httomolibgpu.prep.normalize import normalize
88
from httomolibgpu.prep.stripe import (
9+
DeviceMemStack,
910
remove_stripe_based_sorting,
1011
remove_stripe_ti,
1112
remove_stripe_fw,
@@ -35,6 +36,23 @@ def test_remove_stripe_ti_on_data(data, flats, darks):
3536
assert data_after_stripe_removal.dtype == np.float32
3637

3738

39+
class MaxMemoryHook(cp.cuda.MemoryHook):
40+
def __init__(self, initial=0):
41+
self.max_mem = initial
42+
self.current = initial
43+
44+
def malloc_postprocess(
45+
self, device_id: int, size: int, mem_size: int, mem_ptr: int, pmem_id: int
46+
):
47+
self.current += mem_size
48+
self.max_mem = max(self.max_mem, self.current)
49+
50+
def free_postprocess(
51+
self, device_id: int, mem_size: int, mem_ptr: int, pmem_id: int
52+
):
53+
self.current -= mem_size
54+
55+
3856
def test_remove_stripe_fw_on_data(data, flats, darks):
3957
# --- testing the CuPy implementation from TomoCupy ---#
4058
data = normalize(data, flats, darks, cutoff=10, minus_log=True)
@@ -54,6 +72,42 @@ def test_remove_stripe_fw_on_data(data, flats, darks):
5472
assert data_after_stripe_removal.dtype == np.float32
5573

5674

75+
@pytest.fixture
76+
def ensure_clean_memory():
77+
cp.get_default_memory_pool().free_all_blocks()
78+
cp.get_default_pinned_memory_pool().free_all_blocks()
79+
cache = cp.fft.config.get_plan_cache()
80+
cache.clear()
81+
yield None
82+
cp.get_default_memory_pool().free_all_blocks()
83+
cp.get_default_pinned_memory_pool().free_all_blocks()
84+
cache = cp.fft.config.get_plan_cache()
85+
cache.clear()
86+
87+
88+
@pytest.mark.parametrize("slices", [55, 80])
89+
@pytest.mark.parametrize("level", [1, 3, 7, 11])
90+
@pytest.mark.parametrize("dim_x", [128, 140])
91+
def test_remove_stripe_fw_mem_stack(slices, level, dim_x, ensure_clean_memory):
92+
dim_y = 159
93+
data = cp.random.random_sample((slices, dim_x, dim_y), dtype=np.float32)
94+
hook = MaxMemoryHook()
95+
with hook:
96+
remove_stripe_fw(cp.copy(data), level=level)
97+
actual_mem_peak = hook.max_mem
98+
99+
hook = MaxMemoryHook()
100+
mem_stack = DeviceMemStack()
101+
with hook:
102+
remove_stripe_fw(data.shape, level=level, mem_stack=mem_stack)
103+
assert hook.max_mem == 0
104+
estimated_mem_peak = mem_stack.highwater
105+
106+
# assert actual_mem_peak == estimated_mem_peak
107+
assert actual_mem_peak * 0.99 <= estimated_mem_peak
108+
assert estimated_mem_peak <= actual_mem_peak * 1.01
109+
110+
57111
@pytest.mark.parametrize("angles", [180, 181])
58112
@pytest.mark.parametrize("det_x", [11, 18])
59113
@pytest.mark.parametrize("det_y", [5, 7, 8])

0 commit comments

Comments
 (0)