Skip to content

Commit c415074

Browse files
authored
add jitter to duplicated values for continuous splitting rule (#129)
* jit continuous rule * omit intermediate step
1 parent d202b07 commit c415074

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

pymc_bart/pgbart.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,9 @@ def __init__(
155155
else:
156156
self.split_rules = [ContinuousSplitRule] * self.X.shape[1]
157157

158-
jittered = np.random.normal(self.X, self.X.std(axis=0) / 12)
159-
min_values = np.min(self.X, axis=0)
160-
max_values = np.max(self.X, axis=0)
161-
self.X = np.clip(jittered, min_values, max_values)
158+
for idx, rule in enumerate(self.split_rules):
159+
if rule is ContinuousSplitRule:
160+
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.std(self.X[:, idx]))
162161

163162
init_mean = self.bart.Y.mean()
164163
self.num_observations = self.X.shape[0]
@@ -693,6 +692,21 @@ def inverse_cdf(
693692
return new_indices
694693

695694

695+
@njit
def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
    """
    Jitter duplicated values.

    The first occurrence of each value is left untouched. Every later
    occurrence of an already-seen value is replaced by that value plus
    Gaussian noise with mean 0 and standard deviation ``std / 12``.

    Parameters
    ----------
    array : npt.NDArray[np.float64]
        Values to scan; iterated element by element (assumed 1-D, e.g. a
        single column of the design matrix). It is modified in place.
    std : float
        Reference spread; the noise standard deviation is ``std / 12``.

    Returns
    -------
    npt.NDArray[np.float64]
        The same ``array`` object that was passed in, after jittering.
    """
    # The noise scale does not depend on the element, so compute it once.
    scale = std / 12
    # A set keeps the membership test O(1); a list made the scan quadratic.
    # Only original values are recorded, jittered replacements are not.
    seen = set()
    for idx, num in enumerate(array):
        if num in seen:
            array[idx] = num + np.random.normal(0, scale)
        else:
            seen.add(num)

    return array
708+
709+
696710
def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
697711
"""Compile PyTensor function of the model and the input and output variables.
698712

pymc_bart/split_rules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def divide(available_splitting_values, split_value):
7979
class SubsetSplitRule(SplitRule):
8080
"""
8181
Choose a random subset of the categorical values and branch on belonging to that set.
82-
This is the approach taken by Sameer K. Deshpande.
82+
This is the approach taken by Sameer K. Deshpande.
8383
flexBART: Flexible Bayesian regression trees with categorical predictors. arXiv,
8484
`link <https://arxiv.org/abs/2211.04459>`__
8585
"""

0 commit comments

Comments
 (0)