Skip to content

Commit 2925162

Browse files
authored
jitter array of whole numbers (#133)
1 parent 83f2409 commit 2925162

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

pymc_bart/pgbart.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -697,16 +697,24 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
697697
"""
698698
Jitter duplicated values.
699699
"""
700-
seen = []
701-
for idx, num in enumerate(array):
702-
if num in seen:
703-
array[idx] = num + np.random.normal(0, std / 12)
704-
else:
705-
seen.append(num)
700+
if are_whole_number(array):
701+
seen = []
702+
for idx, num in enumerate(array):
703+
if num in seen:
704+
array[idx] = num + np.random.normal(0, std / 12)
705+
else:
706+
seen.append(num)
706707

707708
return array
708709

709710

711+
@njit
712+
def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_:
713+
"""Check if all values in array are whole numbers"""
714+
new_array = np.mod(array, 1)
715+
return np.all(new_array == 0)
716+
717+
710718
def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
711719
"""Compile PyTensor function of the model and the input and output variables.
712720

0 commit comments

Comments
 (0)