Skip to content

Commit 11ec6e0

Browse files
committed
adding cuda kernel for nansinfs correction and correction to zero calculator
1 parent f3de1a1 commit 11ec6e0

File tree

2 files changed

+63
-18
lines changed

2 files changed

+63
-18
lines changed

httomolibgpu/misc/supp_func.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,18 @@
2626
cp = cupywrapper.cp
2727
cupy_run = cupywrapper.cupy_run
2828

29+
import numpy as np
30+
31+
from unittest.mock import Mock
32+
33+
if cupy_run:
34+
from httomolibgpu.cuda_kernels import load_cuda_module
35+
else:
36+
load_cuda_module = Mock()
37+
2938

3039
def _naninfs_check(
3140
data: cp.ndarray,
32-
correction: bool = True,
3341
verbosity: bool = True,
3442
method_name: Optional[str] = None,
3543
) -> cp.ndarray:
@@ -40,8 +48,6 @@ def _naninfs_check(
4048
----------
4149
data : cp.ndarray
4250
Input CuPy or Numpy array either float32 or uint16 data type.
43-
correction : bool
44-
If correction is enabled then Inf's and NaN's will be replaced by zeros.
4551
verbosity : bool
4652
If enabled, then the printing of the warning happens when data contains infs or nans
4753
method_name : str, optional.
@@ -52,21 +58,53 @@ def _naninfs_check(
5258
ndarray
5359
Uncorrected or corrected (nans and infs converted to zeros) input array.
5460
"""
61+
present_nans_infs_b = False
62+
5563
if cupy_run:
5664
xp = cp.get_array_module(data)
5765
else:
5866
import numpy as xp
5967

60-
if not xp.all(xp.isfinite(data)):
68+
if xp.__name__ == "cupy":
69+
input_type = data.dtype
70+
if len(data.shape) == 2:
71+
dy, dx = data.shape
72+
dz = 1
73+
else:
74+
dz, dy, dx = data.shape
75+
76+
present_nans_infs = cp.zeros(shape=(1)).astype(cp.uint8)
77+
78+
block_x = 128
79+
# setting grid/block parameters
80+
block_dims = (block_x, 1, 1)
81+
grid_x = (dx + block_x - 1) // block_x
82+
grid_y = dy
83+
grid_z = dz
84+
grid_dims = (grid_x, grid_y, grid_z)
85+
params = (data, dz, dy, dx, present_nans_infs)
86+
87+
kernel_args = "remove_nan_inf<{0}>".format(
88+
"float" if input_type == "float32" else "unsigned short"
89+
)
90+
91+
module = load_cuda_module("remove_nan_inf", name_expressions=[kernel_args])
92+
remove_nan_inf_kernel = module.get_function(kernel_args)
93+
remove_nan_inf_kernel(grid_dims, block_dims, params)
94+
95+
if present_nans_infs[0].get() == 1:
96+
present_nans_infs_b = True
97+
else:
98+
if not np.all(np.isfinite(data)):
99+
present_nans_infs_b = True
100+
np.nan_to_num(data, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
101+
102+
if present_nans_infs_b:
61103
if verbosity:
62104
print(
63-
f"Warning!!! Input data to method: {method_name} contains Inf's or/and NaN's."
64-
)
65-
if correction:
66-
print(
67-
"Inf's or/and NaN's will be corrected to finite integers (zeros). It is advisable to check the correctness of the input."
105+
f"Warning!!! Input data to method: {method_name} contains Inf's or/and NaN's. This will be corrected but it sometimes recommended to check the validity of input to the method."
68106
)
69-
xp.nan_to_num(data, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
107+
70108
return data
71109

72110

@@ -100,12 +138,13 @@ def _zeros_check(
100138
else:
101139
import numpy as xp
102140

103-
warning_zeros = False
104-
zero_elements_total = int(xp.count_nonzero(data == 0))
105-
106141
nonzero_elements_total = 1
107142
for tot_elements_mult in data.shape:
108143
nonzero_elements_total *= tot_elements_mult
144+
145+
warning_zeros = False
146+
zero_elements_total = nonzero_elements_total - int(xp.count_nonzero(data))
147+
109148
if (zero_elements_total / nonzero_elements_total) * 100 >= percentage_threshold:
110149
warning_zeros = True
111150
if verbosity:
@@ -140,9 +179,7 @@ def data_checker(
140179
Returns corrected or not data array.
141180
"""
142181

143-
data = _naninfs_check(
144-
data, correction=True, verbosity=verbosity, method_name=method_name
145-
)
182+
data = _naninfs_check(data, verbosity=verbosity, method_name=method_name)
146183

147184
_zeros_check(
148185
data,

tests/test_misc/test_supp_func.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
def test_naninfs_check1():
13-
_data_input = cp.ones(shape=(10, 10, 10)) * 100
13+
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.float32) * 100
1414
_data_output = _naninfs_check(_data_input.copy())
1515

1616
assert_equal(
@@ -22,7 +22,7 @@ def test_naninfs_check1():
2222

2323

2424
def test_naninfs_check1_numpy():
25-
_data_input = np.ones(shape=(10, 10, 10)) * 100
25+
_data_input = np.ones(shape=(10, 10, 10)).astype(np.float32) * 100
2626
_data_output = _naninfs_check(_data_input.copy())
2727

2828
assert_equal(
@@ -95,6 +95,14 @@ def test_naninfs_check3_numpy():
9595
assert _data_output.shape == (10, 10, 10)
9696

9797

98+
def test_naninfs_check4():
99+
_data_input = cp.ones(shape=(10, 10, 10)).astype(cp.uint8) * 100
100+
_data_output = _naninfs_check(_data_input.copy())
101+
102+
assert _data_output.dtype == _data_input.dtype
103+
assert _data_output.shape == (10, 10, 10)
104+
105+
98106
def test_naninfs_check4_numpy():
99107
_data_input = np.ones(shape=(10, 10, 10), dtype=np.uint8) * 100
100108
_data_output = _naninfs_check(_data_input.copy())

0 commit comments

Comments
 (0)