@@ -155,10 +155,9 @@ def __init__(
         else:
             self.split_rules = [ContinuousSplitRule] * self.X.shape[1]

-        jittered = np.random.normal(self.X, self.X.std(axis=0) / 12)
-        min_values = np.min(self.X, axis=0)
-        max_values = np.max(self.X, axis=0)
-        self.X = np.clip(jittered, min_values, max_values)
+        for idx, rule in enumerate(self.split_rules):
+            if rule is ContinuousSplitRule:
+                self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.std(self.X[:, idx]))

         init_mean = self.bart.Y.mean()
         self.num_observations = self.X.shape[0]
@@ -693,6 +692,21 @@ def inverse_cdf(
     return new_indices


+@njit
+def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[np.float_]:
+    """
+    Jitter duplicated values.
+    """
+    seen = []
+    for idx, num in enumerate(array):
+        if num in seen:
+            array[idx] = num + np.random.normal(0, std / 12)
+        else:
+            seen.append(num)
+
+    return array
+
+
 def logp(point, out_vars, vars, shared):  # pylint: disable=redefined-builtin
     """Compile PyTensor function of the model and the input and output variables.
0 commit comments