Skip to content

Commit d95e59d

Browse files
committed
fixing few suggestions, applying black formatting
1 parent 44442f9 commit d95e59d

File tree

11 files changed

+52
-37
lines changed

11 files changed

+52
-37
lines changed

httomolibgpu/misc/corr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def median_filter(
8282
else:
8383
raise ValueError("The input array must be a 3D array")
8484

85-
data = data_checker(data, verbosity=True, method_name="median_filter_or_remove_outlier")
85+
data = data_checker(
86+
data, verbosity=True, method_name="median_filter_or_remove_outlier"
87+
)
8688

8789
if kernel_size not in [3, 5, 7, 9, 11, 13]:
8890
raise ValueError("Please select a correct kernel size: 3, 5, 7, 9, 11, 13")

httomolibgpu/misc/denoise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def total_variation_ROF(
8282
If the input array is not float32 data type.
8383
"""
8484

85-
data = data_checker(data,verbosity=True,method_name="total_variation_ROF")
85+
data = data_checker(data, verbosity=True, method_name="total_variation_ROF")
8686

8787
return ROF_TV(
8888
data, regularisation_parameter, iterations, time_marching_parameter, gpu_id
@@ -129,7 +129,7 @@ def total_variation_PD(
129129
If the input array is not float32 data type.
130130
"""
131131

132-
data_checker(data,verbosity=True,method_name="total_variation_PD")
132+
data_checker(data, verbosity=True, method_name="total_variation_PD")
133133

134134
methodTV = 0
135135
if not isotropic:

httomolibgpu/misc/supp_func.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ def _zeros_check(
7777
method_name: Optional[str] = None,
7878
) -> bool:
7979
"""
80-
This function finds all zeros present in the data. If the amount of zeros is larger than percentage_threshold it prints the warning.
80+
This function finds all zeros present in the data. If the amount of zeros is larger than percentage_threshold it prints the warning.
8181
8282
Parameters
8383
----------
8484
data : cp.ndarray
85-
Input CuPy or Numpy array either float32 or uint16 data type.
85+
Input CuPy or Numpy array.
8686
verbosity : bool
87-
If enabled, then the printing of the warning happens when data contains infs or nans
87+
If enabled, then the printing of the warning happens when data contains infs or nans.
8888
percentage_threshold: float:
8989
If the number of zeros in input data is more than the percentage of all data points, then print the data warning
9090
method_name : str, optional.
@@ -103,7 +103,7 @@ def _zeros_check(
103103
warning_zeros = False
104104
zero_elements_total = int(xp.count_nonzero(data == 0))
105105

106-
nonzero_elements_total = 1
106+
nonzero_elements_total = 1
107107
for tot_elements_mult in data.shape:
108108
nonzero_elements_total *= tot_elements_mult
109109
if (zero_elements_total / nonzero_elements_total) * 100 >= percentage_threshold:
@@ -123,21 +123,21 @@ def data_checker(
123123
) -> bool:
124124
"""
125125
Function that performs the variety of checks on input data, in some cases also correct the data and prints warnings.
126-
Currently it checks for: the presence of infs and nans in data; the number of zero elements.
126+
Currently it checks for: the presence of infs and nans in data; the number of zero elements.
127127
128128
Parameters
129129
----------
130130
data : xp.ndarray
131-
Input CuPy or Numpy array either float32 or uint16 data type.
131+
Input CuPy or Numpy array.
132132
verbosity : bool
133-
If enabled, then the printing of the warning happens when data contains infs or nans
133+
If enabled, then the printing of the warning happens when data contains infs or nans.
134134
method_name : str, optional.
135135
Method's name for which input data is tested.
136136
137137
Returns
138138
-------
139139
cp.ndarray
140-
Returns corrected or not data array
140+
Returns corrected or not data array.
141141
"""
142142

143143
data = _naninfs_check(
@@ -151,4 +151,4 @@ def data_checker(
151151
method_name=method_name,
152152
)
153153

154-
return data
154+
return data

httomolibgpu/prep/alignment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,9 @@ def distortion_correction_proj_discorpy(
8888
if len(data.shape) == 2:
8989
data = cp.expand_dims(data, axis=0)
9090

91-
data = data_checker(data, verbosity=True, method_name="distortion_correction_proj_discorpy")
91+
data = data_checker(
92+
data, verbosity=True, method_name="distortion_correction_proj_discorpy"
93+
)
9294

9395
# Get info from metadata txt file
9496
xcenter, ycenter, list_fact = _load_metadata_txt(metadata_path)

httomolibgpu/prep/normalize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _check_valid_input_normalise(data, flats, darks) -> None:
145145
flats = flats[cp.newaxis, :, :]
146146
if darks.ndim == 2:
147147
darks = darks[cp.newaxis, :, :]
148-
149-
data_checker(data,verbosity=True,method_name="normalize_data")
150-
data_checker(flats,verbosity=True,method_name="normalize_flats")
151-
data_checker(darks,verbosity=True,method_name="normalize_darks")
148+
149+
data_checker(data, verbosity=True, method_name="normalize_data")
150+
data_checker(flats, verbosity=True, method_name="normalize_flats")
151+
data_checker(darks, verbosity=True, method_name="normalize_darks")

httomolibgpu/prep/phase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def paganin_filter_tomopy(
300300
f"Invalid number of dimensions in data: {tomo.ndim},"
301301
" please provide a stack of 2D projections."
302302
)
303-
303+
304304
tomo = data_checker(tomo, verbosity=True, method_name="paganin_filter_tomopy")
305305

306306
dz_orig, dy_orig, dx_orig = tomo.shape

httomolibgpu/prep/stripe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def remove_stripe_ti(
144144
_, _, dx_orig = data.shape
145145
if (dx_orig % 2) != 0:
146146
# the horizontal detector size is odd, data needs to be padded/cropped, for now raising the error
147-
raise ValueError("The horizontal detector size must be even")
147+
raise ValueError("The horizontal detector size must be even")
148148

149149
gamma = beta * ((1 - beta) / (1 + beta)) ** cp.abs(
150150
cp.fft.fftfreq(data.shape[-1]) * data.shape[-1]
@@ -387,7 +387,7 @@ def raven_filter(
387387
"""
388388
if data.dtype != cp.float32:
389389
raise ValueError("The input data should be float32 data type")
390-
390+
391391
data = data_checker(data, verbosity=True, method_name="raven_filter")
392392

393393
# Padding of the sinogram

httomolibgpu/recon/rotation.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def find_center_vo(
112112

113113
data = data_checker(data, verbosity=True, method_name="find_center_vo")
114114

115-
angles_tot, detY_size, detX_size = data.shape
115+
angles_tot, detY_size, detX_size = data.shape
116116

117117
if ind is None:
118118
ind = detY_size // 2 # middle slice index
@@ -459,8 +459,8 @@ def find_center_360(
459459
"""
460460
if data.ndim != 3:
461461
raise ValueError("A 3D array must be provided")
462-
463-
data = data_checker(data, verbosity=True, method_name="find_center_360")
462+
463+
data = data_checker(data, verbosity=True, method_name="find_center_360")
464464

465465
# this method works with a 360-degree sinogram.
466466
if ind is None:
@@ -781,8 +781,8 @@ def find_center_pc(
781781
Rotation axis location.
782782
"""
783783

784-
proj1 = data_checker(proj1, verbosity=True, method_name="find_center_pc")
785-
proj2 = data_checker(proj2, verbosity=True, method_name="find_center_pc")
784+
proj1 = data_checker(proj1, verbosity=True, method_name="find_center_pc")
785+
proj2 = data_checker(proj2, verbosity=True, method_name="find_center_pc")
786786

787787
imgshift = 0.0 if rotc_guess is None else rotc_guess - (proj1.shape[1] - 1.0) / 2.0
788788

@@ -802,4 +802,6 @@ def find_center_pc(
802802
center = (proj1.shape[1] + shiftr[0][1] - 1.0) / 2.0
803803

804804
return np.float32(center + imgshift)
805+
806+
805807
##%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

tests/test_misc/test_supp_func.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
_zeros_check,
77
data_checker,
88
)
9-
from numpy.testing import assert_allclose, assert_equal
10-
11-
eps = 1e-6
9+
from numpy.testing import assert_equal
1210

1311

1412
def test_naninfs_check1():
@@ -36,7 +34,7 @@ def test_naninfs_check1_numpy():
3634

3735

3836
def test_naninfs_check2():
39-
_data_input = cp.ones(shape=(10, 10, 10), dtype=cp.float32) * 100
37+
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.float32) * 100
4038
_data_input[1, 1, 1] = -cp.inf
4139
_data_input[1, 1, 2] = cp.inf
4240
_data_input[1, 1, 3] = cp.nan
@@ -82,7 +80,7 @@ def test_naninfs_check2_numpy():
8280

8381

8482
def test_naninfs_check3():
85-
_data_input = cp.ones(shape=(10, 10, 10), dtype=cp.uint16) * 100
83+
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.uint16) * 100
8684
_data_output = _naninfs_check(_data_input.copy())
8785

8886
assert _data_output.dtype == _data_input.dtype
@@ -97,8 +95,16 @@ def test_naninfs_check3_numpy():
9795
assert _data_output.shape == (10, 10, 10)
9896

9997

98+
def test_naninfs_check4_numpy():
99+
_data_input = np.ones(shape=(10, 10, 10), dtype=np.uint8) * 100
100+
_data_output = _naninfs_check(_data_input.copy())
101+
102+
assert _data_output.dtype == _data_input.dtype
103+
assert _data_output.shape == (10, 10, 10)
104+
105+
100106
def test_zeros_check1():
101-
_data_input = cp.ones(shape=(10, 10, 10), dtype=cp.float32) * 100
107+
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.float32) * 100
102108
warning_zeros = _zeros_check(_data_input.copy())
103109

104110
assert warning_zeros == False
@@ -112,7 +118,7 @@ def test_zeros_check1_numpy():
112118

113119

114120
def test_zeros_check2():
115-
_data_input = cp.ones(shape=(10, 10, 10), dtype=cp.float32) * 100
121+
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.float32) * 100
116122
_data_input[2:7, :, :] = 0.0
117123
warning_zeros = _zeros_check(_data_input.copy())
118124

@@ -128,7 +134,7 @@ def test_zeros_check2_numpy():
128134

129135

130136
def test_zeros_check3():
131-
_data_input = cp.ones(shape=(10, 10, 10), dtype=cp.float32) * 100
137+
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.float32) * 100
132138
_data_input[3:7, :, :] = 0.0
133139
warning_zeros = _zeros_check(_data_input.copy())
134140

@@ -144,7 +150,7 @@ def test_zeros_check3_numpy():
144150

145151

146152
def test_data_checker_numpy():
147-
_data_input = np.ones(shape=(10, 10, 10), dtype=np.float32)
153+
_data_input = np.ones(shape=(10, 10, 10)).astype(np.float32)
148154
_data_input[1, 1, 1] = -np.inf
149155
_data_input[1, 1, 2] = np.inf
150156
_data_input[1, 1, 3] = np.nan
@@ -166,8 +172,9 @@ def test_data_checker_numpy():
166172
assert _data_output.dtype == _data_input.dtype
167173
assert _data_output.shape == (10, 10, 10)
168174

175+
169176
def test_data_checker():
170-
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.float32) * 100.0
177+
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.float32) * 100.0
171178
_data_input[1, 1, 1] = -cp.inf
172179
_data_input[1, 1, 2] = cp.inf
173180
_data_input[1, 1, 3] = cp.nan
@@ -187,4 +194,4 @@ def test_data_checker():
187194
0.0,
188195
)
189196
assert _data_output.dtype == _data_input.dtype
190-
assert _data_output.shape == (10, 10, 10)
197+
assert _data_output.shape == (10, 10, 10)

tests/test_recon/test_algorithm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def test_reconstruct_FBP_2d_astra(data, flats, darks, ensure_clean_memory):
3434
assert recon_data.dtype == np.float32
3535
assert recon_data.shape == (recon_size, 128, recon_size)
3636

37+
3738
def test_reconstruct_FBP3d_tomobar_1(data, flats, darks, ensure_clean_memory):
3839
recon_data = FBP3d_tomobar(
3940
normalize_cupy(data, flats, darks, cutoff=10, minus_log=True),

0 commit comments

Comments
 (0)