Skip to content

Commit 30d78c9

Browse files
committed
modification to the supplementary function
1 parent 3302b09 commit 30d78c9

File tree

2 files changed

+89
-3
lines changed

2 files changed

+89
-3
lines changed

httomolibgpu/misc/supp_func.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _naninfs_check(
3939
Parameters
4040
----------
4141
data : cp.ndarray
42-
Input CuPy array either float32 or uint16 data type.
42+
Input CuPy or Numpy array either float32 or uint16 data type.
4343
correction : bool
4444
If correction is enabled then Inf's and NaN's will be replaced by zeros.
4545
verbosity : bool
@@ -50,7 +50,7 @@ def _naninfs_check(
5050
Returns
5151
-------
5252
ndarray
53-
Corrected (or not) CuPy array.
53+
Uncorrected or corrected (nans and infs converted to zeros) input array.
5454
"""
5555
if cupy_run:
5656
xp = cp.get_array_module(data)
@@ -82,7 +82,7 @@ def _zeros_check(
8282
Parameters
8383
----------
8484
data : cp.ndarray
85-
Input CuPy array either float32 or uint16 data type.
85+
Input CuPy or Numpy array either float32 or uint16 data type.
8686
verbosity : bool
8787
If enabled, then the printing of the warning happens when data contains infs or nans
8888
percentage_threshold: float:
@@ -111,3 +111,41 @@ def _zeros_check(
111111
)
112112

113113
return warning_zeros
114+
115+
116+
def data_checker(
117+
data: cp.ndarray,
118+
verbosity: bool = True,
119+
method_name: Optional[str] = None,
120+
) -> bool:
121+
"""
122+
Function that performs the variety of checks on input data, in some cases also correct the data and prints warnings.
123+
Currently it checks for: the presence of infs and nans in data; the number of zero elements.
124+
125+
Parameters
126+
----------
127+
data : xp.ndarray
128+
Input CuPy or Numpy array either float32 or uint16 data type.
129+
verbosity : bool
130+
If enabled, then the printing of the warning happens when data contains infs or nans
131+
method_name : str, optional.
132+
Method's name for which input data is tested.
133+
134+
Returns
135+
-------
136+
cp.ndarray
137+
Returns corrected or not data array
138+
"""
139+
140+
data = _naninfs_check(
141+
data, correction=True, verbosity=verbosity, method_name=method_name
142+
)
143+
144+
_zeros_check(
145+
data,
146+
verbosity=verbosity,
147+
percentage_threshold=50,
148+
method_name=method_name,
149+
)
150+
151+
return data

tests/test_misc/test_supp_func.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from httomolibgpu.misc.supp_func import (
55
_naninfs_check,
66
_zeros_check,
7+
data_checker,
78
)
89
from numpy.testing import assert_allclose, assert_equal
910

@@ -140,3 +141,50 @@ def test_zeros_check3_numpy():
140141
warning_zeros = _zeros_check(_data_input.copy())
141142

142143
assert warning_zeros == False
144+
145+
146+
def test_data_checker_numpy():
147+
_data_input = np.ones(shape=(10, 10, 10), dtype=np.float32)
148+
_data_input[1, 1, 1] = -np.inf
149+
_data_input[1, 1, 2] = np.inf
150+
_data_input[1, 1, 3] = np.nan
151+
152+
_data_output = data_checker(_data_input.copy())
153+
154+
assert_equal(
155+
_data_output[1, 1, 1],
156+
0,
157+
)
158+
assert_equal(
159+
_data_output[1, 1, 2],
160+
0,
161+
)
162+
assert_equal(
163+
_data_output[1, 1, 3],
164+
0,
165+
)
166+
assert _data_output.dtype == _data_input.dtype
167+
assert _data_output.shape == (10, 10, 10)
168+
169+
def test_data_checker():
170+
_data_input = cp.ones(shape=(10, 10, 10), dtype=cp.float32)*100
171+
_data_input[1, 1, 1] = -cp.inf
172+
_data_input[1, 1, 2] = cp.inf
173+
_data_input[1, 1, 3] = cp.nan
174+
175+
_data_output = data_checker(_data_input.copy())
176+
177+
assert_equal(
178+
_data_output[1, 1, 1].get(),
179+
0.0,
180+
)
181+
assert_equal(
182+
_data_output[1, 1, 2].get(),
183+
0.0,
184+
)
185+
assert_equal(
186+
_data_output[1, 1, 3].get(),
187+
0.0,
188+
)
189+
assert _data_output.dtype == _data_input.dtype
190+
assert _data_output.shape == (10, 10, 10)

0 commit comments

Comments
 (0)