Skip to content

Commit 7c02015

Browse files
authored
fix split_prior bug (#115)
* fix split_prior bug * fix type * fix type
1 parent d318b38 commit 7c02015

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

pymc_bart/bart.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ class BARTRV(RandomVariable):
4343
_print_name: Tuple[str, str] = ("BART", "\\operatorname{BART}")
4444
all_trees = List[List[List[Tree]]]
4545

46-
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
46+
def _supp_shape_from_params(
47+
self, dist_params, rep_param_idx=1, param_shapes=None
48+
): # pylint: disable=arguments-renamed
4749
return dist_params[0].shape[:1]
4850

4951
@classmethod
@@ -126,7 +128,7 @@ def __new__(
126128
alpha: float = 0.95,
127129
beta: float = 2.0,
128130
response: str = "constant",
129-
split_prior: Optional[List[float]] = None,
131+
split_prior: Optional[npt.NDArray[np.float_]] = None,
130132
split_rules: Optional[List[SplitRule]] = None,
131133
separate_trees: Optional[bool] = False,
132134
**kwargs,
@@ -141,8 +143,7 @@ def __new__(
141143

142144
X, Y = preprocess_xy(X, Y)
143145

144-
if split_prior is None:
145-
split_prior = []
146+
split_prior = np.array([]) if split_prior is None else np.asarray(split_prior)
146147

147148
bart_op = type(
148149
f"BART_{name}",

pymc_bart/pgbart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,10 @@ def __init__(
145145
self.trees_shape = self.shape if self.bart.separate_trees else 1
146146
self.leaves_shape = self.shape if not self.bart.separate_trees else 1
147147

148-
if self.bart.split_prior:
149-
self.alpha_vec = self.bart.split_prior
148+
if self.bart.split_prior.size == 0:
149+
self.alpha_vec = np.ones(self.X.shape[1])
150150
else:
151-
self.alpha_vec = np.ones(self.X.shape[1], dtype=np.int32)
151+
self.alpha_vec = self.bart.split_prior
152152

153153
if self.bart.split_rules:
154154
self.split_rules = self.bart.split_rules

0 commit comments

Comments
 (0)