Skip to content

Commit 2b2da47

Browse files
committed
Code clean up
1 parent afe6323 commit 2b2da47

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,8 @@ def raven_filter(
413413

414414
input_type = sinogram.dtype
415415

416-
if input_type not in ["float32", "float64"]:
417-
raise ValueError("The input data should be either float32 or float64 data type")
416+
if input_type not in ["float32"]:
417+
raise ValueError("The input data should be float32 data type")
418418

419419
# Padding of the sinogram
420420
sinogram = cp.pad(sinogram, ((pad_y, pad_y), (0, 0), (pad_x, pad_x)), mode=pad_method)
@@ -423,12 +423,15 @@ def raven_filter(
423423
fft_data = fft2(sinogram, axes=(0, 2), overwrite_x=True)
424424
fft_data_shifted = fftshift(fft_data, axes=(0, 2))
425425

426+
# Calculation type
427+
calc_type = fft_data_shifted.dtype
428+
426429
# Setup various values for the filter
427430
height, images, width = sinogram.shape
428431

429432
# Set the input type of the kernel
430433
kernel_args = "raven_filter<{0}>".format(
431-
"float" if input_type == "float32" else "double"
434+
"float" if calc_type == "complex64" else "double"
432435
)
433436

434437
# setting grid/block parameters

tests/test_prep/test_stripe.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -155,23 +155,23 @@ def test_raven_filter_performance(ensure_clean_memory):
155155

156156
assert "performance in ms" == duration_ms
157157

158-
@pytest.mark.perf
159-
def test_raven_filter_cpu_performance(ensure_clean_memory):
160-
data_host = (
161-
np.random.random_sample(size=(1801, 5, 2560)).astype(np.float32) * 2.0 + 0.001
162-
)
163-
data = cp.asarray(data_host, dtype=np.float32)
164-
165-
# do a cold run first
166-
raven_filter_cpu(cp.copy(data).get())
167-
168-
start = time.perf_counter_ns()
169-
for _ in range(10):
170-
raven_filter_cpu(cp.copy(data).get())
171-
172-
duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10
173-
174-
assert "performance in ms" == duration_ms
158+
# @pytest.mark.perf
159+
# def test_raven_filter_cpu_performance(ensure_clean_memory):
160+
# data_host = (
161+
# np.random.random_sample(size=(1801, 5, 2560)).astype(np.float32) * 2.0 + 0.001
162+
# )
163+
# data = cp.asarray(data_host, dtype=np.float32)
164+
#
165+
# # do a cold run first
166+
# raven_filter_cpu(cp.copy(data).get())
167+
#
168+
# start = time.perf_counter_ns()
169+
# for _ in range(10):
170+
# raven_filter_cpu(cp.copy(data).get())
171+
#
172+
# duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10
173+
#
174+
# assert "performance in ms" == duration_ms
175175

176176

177177
def test_remove_all_stripe_on_data(data, flats, darks):

0 commit comments

Comments
 (0)