Skip to content

Commit f5a8b0a

Browse files
committed
Add zenodo test for remove_stripe_fw
1 parent 2593790 commit f5a8b0a

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ def remove_stripe_fw(data: cp.ndarray, sigma: float=1, wname: str='sym16', level
559559
sli_shape = new_sli_shape
560560
for c in cc:
561561
mem_stack.free(np.prod(c) * np.float32().itemsize)
562+
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
562563
mem_stack.free(np.prod(sli_shape) * np.float32().itemsize)
563564
return
564565

@@ -585,8 +586,7 @@ def remove_stripe_fw(data: cp.ndarray, sigma: float=1, wname: str='sym16', level
585586

586587
data = sli[:, 0, (nproj_pad - nproj)//2:(nproj_pad + nproj) // 2, :ni]
587588
data = data.swapaxes(0, 1)
588-
589-
return data
589+
return cp.ascontiguousarray(data)
590590

591591

592592
######## Optimized version for Vo-all ring removal in tomopy########

zenodo-tests/test_prep/test_stripe.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numpy.testing import assert_allclose
66
from httomolibgpu.prep.stripe import (
77
remove_stripe_based_sorting,
8+
remove_stripe_fw,
89
remove_stripe_ti,
910
remove_all_stripe,
1011
raven_filter,
@@ -98,6 +99,50 @@ def test_remove_stripe_ti_i12_dataset4(
9899
assert output.flags.c_contiguous
99100

100101

102+
@pytest.mark.parametrize(
103+
"dataset_fixture, sigma_val, level, norm_res_expected",
104+
[
105+
(
106+
"i12_dataset4",
107+
0.01,
108+
7,
109+
52.4856,
110+
),
111+
(
112+
"i12_dataset4",
113+
0.3,
114+
5,
115+
53.7807,
116+
),
117+
(
118+
"i12_dataset4",
119+
1.0,
120+
10,
121+
262.2167,
122+
),
123+
],
124+
ids=["case_001", "case_003", "case_006"],
125+
)
126+
def test_remove_stripe_fw_i12_dataset4(
127+
request, dataset_fixture, sigma_val, level, norm_res_expected
128+
):
129+
dataset = request.getfixturevalue(dataset_fixture)
130+
data_normalised = normalize(dataset[0], dataset[2], dataset[3], minus_log=True)
131+
132+
del dataset
133+
force_clean_gpu_memory()
134+
135+
output = remove_stripe_fw(cp.copy(data_normalised), sigma=sigma_val, level=level)
136+
137+
residual_calc = data_normalised - output
138+
norm_res = cp.linalg.norm(residual_calc.flatten())
139+
140+
assert isclose(norm_res, norm_res_expected, abs_tol=10**-4)
141+
142+
assert output.dtype == np.float32
143+
assert output.flags.c_contiguous
144+
145+
101146
@pytest.mark.parametrize(
102147
"dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected",
103148
[

0 commit comments

Comments
 (0)