66
77from httomolibgpu .prep .normalize import normalize
88from httomolibgpu .prep .stripe import (
9+ DeviceMemStack ,
910 remove_stripe_based_sorting ,
1011 remove_stripe_ti ,
1112 remove_stripe_fw ,
@@ -35,6 +36,23 @@ def test_remove_stripe_ti_on_data(data, flats, darks):
3536 assert data_after_stripe_removal .dtype == np .float32
3637
3738
39+ class MaxMemoryHook (cp .cuda .MemoryHook ):
40+ def __init__ (self , initial = 0 ):
41+ self .max_mem = initial
42+ self .current = initial
43+
44+ def malloc_postprocess (
45+ self , device_id : int , size : int , mem_size : int , mem_ptr : int , pmem_id : int
46+ ):
47+ self .current += mem_size
48+ self .max_mem = max (self .max_mem , self .current )
49+
50+ def free_postprocess (
51+ self , device_id : int , mem_size : int , mem_ptr : int , pmem_id : int
52+ ):
53+ self .current -= mem_size
54+
55+
3856def test_remove_stripe_fw_on_data (data , flats , darks ):
3957 # --- testing the CuPy implementation from TomoCupy ---#
4058 data = normalize (data , flats , darks , cutoff = 10 , minus_log = True )
@@ -54,6 +72,42 @@ def test_remove_stripe_fw_on_data(data, flats, darks):
5472 assert data_after_stripe_removal .dtype == np .float32
5573
5674
75+ @pytest .fixture
76+ def ensure_clean_memory ():
77+ cp .get_default_memory_pool ().free_all_blocks ()
78+ cp .get_default_pinned_memory_pool ().free_all_blocks ()
79+ cache = cp .fft .config .get_plan_cache ()
80+ cache .clear ()
81+ yield None
82+ cp .get_default_memory_pool ().free_all_blocks ()
83+ cp .get_default_pinned_memory_pool ().free_all_blocks ()
84+ cache = cp .fft .config .get_plan_cache ()
85+ cache .clear ()
86+
87+
88+ @pytest .mark .parametrize ("slices" , [55 , 80 ])
89+ @pytest .mark .parametrize ("level" , [1 , 3 , 7 , 11 ])
90+ @pytest .mark .parametrize ("dim_x" , [128 , 140 ])
91+ def test_remove_stripe_fw_mem_stack (slices , level , dim_x , ensure_clean_memory ):
92+ dim_y = 159
93+ data = cp .random .random_sample ((slices , dim_x , dim_y ), dtype = np .float32 )
94+ hook = MaxMemoryHook ()
95+ with hook :
96+ remove_stripe_fw (cp .copy (data ), level = level )
97+ actual_mem_peak = hook .max_mem
98+
99+ hook = MaxMemoryHook ()
100+ mem_stack = DeviceMemStack ()
101+ with hook :
102+ remove_stripe_fw (data .shape , level = level , mem_stack = mem_stack )
103+ assert hook .max_mem == 0
104+ estimated_mem_peak = mem_stack .highwater
105+
106+ # assert actual_mem_peak == estimated_mem_peak
107+ assert actual_mem_peak * 0.99 <= estimated_mem_peak
108+ assert estimated_mem_peak <= actual_mem_peak * 1.01
109+
110+
57111@pytest .mark .parametrize ("angles" , [180 , 181 ])
58112@pytest .mark .parametrize ("det_x" , [11 , 18 ])
59113@pytest .mark .parametrize ("det_y" , [5 , 7 , 8 ])
0 commit comments