Skip to content

Commit 7fb9c39

Browse files
authored
fig bug with nans (#136)
1 parent 5de28ff commit 7fb9c39

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

pymc_bart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626

2727
__all__ = ["BART", "PGBART"]
28-
__version__ = "0.5.6"
28+
__version__ = "0.5.7"
2929

3030

3131
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]

pymc_bart/pgbart.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def __init__(
157157

158158
for idx, rule in enumerate(self.split_rules):
159159
if rule is ContinuousSplitRule:
160-
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.std(self.X[:, idx]))
160+
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx]))
161161

162162
init_mean = self.bart.Y.mean()
163163
self.num_observations = self.X.shape[0]
@@ -700,7 +700,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
700700
if are_whole_number(array):
701701
seen = []
702702
for idx, num in enumerate(array):
703-
if num in seen:
703+
if num in seen and not np.isnan(num):
704704
array[idx] = num + np.random.normal(0, std / 12)
705705
else:
706706
seen.append(num)
@@ -711,8 +711,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
711711
@njit
712712
def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_:
713713
"""Check if all values in array are whole numbers"""
714-
new_array = np.mod(array, 1)
715-
return np.all(new_array == 0)
714+
return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
716715

717716

718717
def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin

0 commit comments

Comments
 (0)