2626cp = cupywrapper .cp
2727cupy_run = cupywrapper .cupy_run
2828
29+ import numpy as np
30+
31+ from unittest .mock import Mock
32+
33+ if cupy_run :
34+ from httomolibgpu .cuda_kernels import load_cuda_module
35+ else :
36+ load_cuda_module = Mock ()
37+
2938
3039def _naninfs_check (
3140 data : cp .ndarray ,
32- correction : bool = True ,
3341 verbosity : bool = True ,
3442 method_name : Optional [str ] = None ,
3543) -> cp .ndarray :
@@ -40,8 +48,6 @@ def _naninfs_check(
4048 ----------
4149 data : cp.ndarray
4250 Input CuPy or Numpy array either float32 or uint16 data type.
43- correction : bool
44- If correction is enabled then Inf's and NaN's will be replaced by zeros.
4551 verbosity : bool
4652 If enabled, then the printing of the warning happens when data contains infs or nans
4753 method_name : str, optional.
@@ -52,21 +58,53 @@ def _naninfs_check(
5258 ndarray
5359 Uncorrected or corrected (nans and infs converted to zeros) input array.
5460 """
61+ present_nans_infs_b = False
62+
5563 if cupy_run :
5664 xp = cp .get_array_module (data )
5765 else :
5866 import numpy as xp
5967
60- if not xp .all (xp .isfinite (data )):
68+ if xp .__name__ == "cupy" :
69+ input_type = data .dtype
70+ if len (data .shape ) == 2 :
71+ dy , dx = data .shape
72+ dz = 1
73+ else :
74+ dz , dy , dx = data .shape
75+
76+ present_nans_infs = cp .zeros (shape = (1 )).astype (cp .uint8 )
77+
78+ block_x = 128
79+ # setting grid/block parameters
80+ block_dims = (block_x , 1 , 1 )
81+ grid_x = (dx + block_x - 1 ) // block_x
82+ grid_y = dy
83+ grid_z = dz
84+ grid_dims = (grid_x , grid_y , grid_z )
85+ params = (data , dz , dy , dx , present_nans_infs )
86+
87+ kernel_args = "remove_nan_inf<{0}>" .format (
88+ "float" if input_type == "float32" else "unsigned short"
89+ )
90+
91+ module = load_cuda_module ("remove_nan_inf" , name_expressions = [kernel_args ])
92+ remove_nan_inf_kernel = module .get_function (kernel_args )
93+ remove_nan_inf_kernel (grid_dims , block_dims , params )
94+
95+ if present_nans_infs [0 ].get () == 1 :
96+ present_nans_infs_b = True
97+ else :
98+ if not np .all (np .isfinite (data )):
99+ present_nans_infs_b = True
100+ np .nan_to_num (data , copy = False , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
101+
102+ if present_nans_infs_b :
61103 if verbosity :
62104 print (
63- f"Warning!!! Input data to method: { method_name } contains Inf's or/and NaN's."
64- )
65- if correction :
66- print (
67- "Inf's or/and NaN's will be corrected to finite integers (zeros). It is advisable to check the correctness of the input."
105+ f"Warning!!! Input data to method: { method_name } contains Inf's or/and NaN's. This will be corrected but it sometimes recommended to check the validity of input to the method."
68106 )
69- xp . nan_to_num ( data , copy = False , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
107+
70108 return data
71109
72110
@@ -100,12 +138,13 @@ def _zeros_check(
100138 else :
101139 import numpy as xp
102140
103- warning_zeros = False
104- zero_elements_total = int (xp .count_nonzero (data == 0 ))
105-
106141 nonzero_elements_total = 1
107142 for tot_elements_mult in data .shape :
108143 nonzero_elements_total *= tot_elements_mult
144+
145+ warning_zeros = False
146+ zero_elements_total = nonzero_elements_total - int (xp .count_nonzero (data ))
147+
109148 if (zero_elements_total / nonzero_elements_total ) * 100 >= percentage_threshold :
110149 warning_zeros = True
111150 if verbosity :
@@ -140,9 +179,7 @@ def data_checker(
140179 Returns corrected or not data array.
141180 """
142181
143- data = _naninfs_check (
144- data , correction = True , verbosity = verbosity , method_name = method_name
145- )
182+ data = _naninfs_check (data , verbosity = verbosity , method_name = method_name )
146183
147184 _zeros_check (
148185 data ,
0 commit comments