1414from numpy .testing import assert_allclose
1515from .stripe_cpu_reference import raven_filter_numpy
1616
17+
1718def test_remove_stripe_ti_on_data (data , flats , darks ):
1819 # --- testing the CuPy implementation from TomoCupy ---#
1920 data = normalize (data , flats , darks , cutoff = 10 , minus_log = True )
@@ -53,6 +54,7 @@ def test_remove_stripe_ti_on_data(data, flats, darks):
5354# np.median(corrected_data), np.median(corrected_host_data), rtol=1e-6
5455# )
5556
57+
5658def test_stripe_removal_sorting_cupy (data , flats , darks ):
5759 # --- testing the CuPy port of TomoPy's implementation ---#
5860 data = normalize (data , flats , darks , cutoff = 10 , minus_log = True )
@@ -67,6 +69,7 @@ def test_stripe_removal_sorting_cupy(data, flats, darks):
6769 assert corrected_data .dtype == np .float32
6870 assert corrected_data .flags .c_contiguous
6971
72+
7073def test_stripe_raven_cupy (data , flats , darks ):
7174 # --- testing the CuPy port of TomoPy's implementation ---#
7275
@@ -82,13 +85,16 @@ def test_stripe_raven_cupy(data, flats, darks):
8285 assert data_after_raven_gpu .dtype == np .float32
8386 assert data_after_raven_gpu .shape == data_after_raven_cpu .shape
8487
88+
8589@pytest .mark .parametrize ("uvalue" , [20 , 50 , 100 ])
8690@pytest .mark .parametrize ("nvalue" , [2 , 4 , 6 ])
8791@pytest .mark .parametrize ("vvalue" , [2 , 4 ])
8892@pytest .mark .parametrize ("pad_x" , [0 , 10 , 20 ])
8993@pytest .mark .parametrize ("pad_y" , [0 , 10 , 20 ])
9094@cp .testing .numpy_cupy_allclose (rtol = 0 , atol = 3e-01 )
91- def test_stripe_raven_parameters_cupy (ensure_clean_memory , xp , uvalue , nvalue , vvalue , pad_x , pad_y ):
95+ def test_stripe_raven_parameters_cupy (
96+ ensure_clean_memory , xp , uvalue , nvalue , vvalue , pad_x , pad_y
97+ ):
9298 # because it's random, we explicitly seed and use numpy only, to match the data
9399 np .random .seed (12345 )
94100 data = np .random .random_sample (size = (256 , 5 , 512 )).astype (np .float32 ) * 2.0 + 0.001
@@ -97,14 +103,15 @@ def test_stripe_raven_parameters_cupy(ensure_clean_memory, xp, uvalue, nvalue, v
97103 if xp .__name__ == "numpy" :
98104 results = raven_filter_numpy (
99105 data , uvalue = uvalue , nvalue = nvalue , vvalue = vvalue , pad_x = pad_x , pad_y = pad_y
100- ).astype (np .float32 )
106+ ).astype (np .float32 )
101107 else :
102108 results = raven_filter (
103109 data , uvalue = uvalue , nvalue = nvalue , vvalue = vvalue , pad_x = pad_x , pad_y = pad_y
104110 ).get ()
105111
106112 return xp .asarray (results )
107113
114+
108115@pytest .mark .perf
109116def test_stripe_removal_sorting_cupy_performance (ensure_clean_memory ):
110117 data_host = (
@@ -154,6 +161,7 @@ def test_remove_stripe_ti_performance(ensure_clean_memory):
154161
155162 assert "performance in ms" == duration_ms
156163
164+
157165@pytest .mark .perf
158166def test_raven_filter_performance (ensure_clean_memory ):
159167 data_host = (
@@ -178,6 +186,7 @@ def test_raven_filter_performance(ensure_clean_memory):
178186
179187 assert "performance in ms" == duration_ms
180188
189+
181190def test_remove_all_stripe_on_data (data , flats , darks ):
182191 # --- testing the CuPy implementation from TomoCupy ---#
183192 data = normalize (data , flats , darks , cutoff = 10 , minus_log = True )
@@ -190,8 +199,12 @@ def test_remove_all_stripe_on_data(data, flats, darks):
190199 )
191200 assert_allclose (np .median (data_after_stripe_removal ), 0.015338 , rtol = 1e-04 )
192201 assert_allclose (np .max (data_after_stripe_removal ), 2.298123 , rtol = 1e-05 )
193- assert_allclose (np .median (data_after_stripe_removal , axis = (1 , 2 )).sum (), 2.788046 , rtol = 1e-6 )
194- assert_allclose (np .median (data_after_stripe_removal , axis = (0 , 1 )).sum (), 28.661312 , rtol = 1e-6 )
202+ assert_allclose (
203+ np .median (data_after_stripe_removal , axis = (1 , 2 )).sum (), 2.788046 , rtol = 1e-6
204+ )
205+ assert_allclose (
206+ np .median (data_after_stripe_removal , axis = (0 , 1 )).sum (), 28.661312 , rtol = 1e-6
207+ )
195208
196209 data = None #: free up GPU memory
197210 # make sure the output is float32
0 commit comments