@@ -157,7 +157,7 @@ def __init__(
157
157
158
158
for idx , rule in enumerate (self .split_rules ):
159
159
if rule is ContinuousSplitRule :
160
- self .X [:, idx ] = jitter_duplicated (self .X [:, idx ], np .std (self .X [:, idx ]))
160
+ self .X [:, idx ] = jitter_duplicated (self .X [:, idx ], np .nanstd (self .X [:, idx ]))
161
161
162
162
init_mean = self .bart .Y .mean ()
163
163
self .num_observations = self .X .shape [0 ]
@@ -700,7 +700,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
700
700
if are_whole_number (array ):
701
701
seen = []
702
702
for idx , num in enumerate (array ):
703
- if num in seen :
703
+ if num in seen and not np . isnan ( num ) :
704
704
array [idx ] = num + np .random .normal (0 , std / 12 )
705
705
else :
706
706
seen .append (num )
@@ -711,8 +711,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
711
711
@njit
712
712
def are_whole_number (array : npt .NDArray [np .float_ ]) -> np .bool_ :
713
713
"""Check if all values in array are whole numbers"""
714
- new_array = np .mod (array , 1 )
715
- return np .all (new_array == 0 )
714
+ return np .all (np .mod (array [~ np .isnan (array )], 1 ) == 0 )
716
715
717
716
718
717
def logp (point , out_vars , vars , shared ): # pylint: disable=redefined-builtin
0 commit comments