
Commit 842a696

BART with non-gaussian likelihoods (#4675)
* allow unbounded likelihoods, add inv_link and small refactor
* add test
* fix test
* set jitter optional, update release notes
1 parent 3447619 commit 842a696

File tree: 5 files changed, +208 -42 lines changed

RELEASE-NOTES.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -5,6 +5,9 @@
 + A deprecation warning from the `semver` package we use for checking backend compatibility was dealt with (see [#4547](https://github.com/pymc-devs/pymc3/pull/4547)).
 + `theano.printing.pydotprint` is now hotfixed upon import (see [#4594](https://github.com/pymc-devs/pymc3/pull/4594)).
 
+### New Features
++ BART with non-gaussian likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675)).
+
 ## PyMC3 3.11.2 (14 March 2021)
 
 ### New Features
```
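The headline change is easiest to see end to end. Below is a minimal, hypothetical usage sketch of what this commit enables: pairing `pm.BART` with a Bernoulli likelihood through the new `inv_link` argument. The synthetic data and the choice of `scipy.special.expit` as the inverse link are illustrative assumptions, not taken from the commit itself.

```python
# Hedged sketch: BART with a non-gaussian (Bernoulli) likelihood via inv_link.
import numpy as np
import pymc3 as pm
from scipy.special import expit

rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(100, 2))
Y = (X[:, 0] > 5).astype(float)  # synthetic binary response

with pm.Model():
    # inv_link squashes the latent sum-of-trees onto the (0, 1) probability scale
    mu = pm.BART("mu", X, Y, m=50, inv_link=expit)
    y_obs = pm.Bernoulli("y_obs", p=mu, observed=Y)
    trace = pm.sample(1000, tune=1000)  # PGBART is assigned to mu automatically
```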

pymc3/distributions/bart.py

Lines changed: 69 additions & 17 deletions
```diff
@@ -23,11 +23,24 @@
 
 
 class BaseBART(NoDistribution):
-    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
-
+    def __init__(
+        self,
+        X,
+        Y,
+        m=200,
+        alpha=0.25,
+        split_prior=None,
+        scale=None,
+        inv_link=None,
+        jitter=False,
+        *args,
+        **kwargs,
+    ):
+
+        self.jitter = jitter
         self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)
 
-        super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)
+        super().__init__(shape=X.shape[0], dtype="float64", testval=self.Y.mean(), *args, **kwargs)
 
         if self.X.ndim != 2:
             raise ValueError("The design matrix X must have two dimensions")
```
```diff
@@ -48,13 +61,24 @@ def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
                 "The value for the alpha parameter for the tree structure "
                 "must be in the interval (0, 1)"
             )
+        self.m = m
+        self.alpha = alpha
+        self.y_std = Y.std()
+
+        if scale is None:
+            self.leaf_scale = NormalSampler(sigma=None)
+        elif isinstance(scale, (int, float)):
+            self.leaf_scale = NormalSampler(sigma=Y.std() / self.m ** scale)
+
+        if inv_link is None:
+            self.inv_link = lambda x: x
+        else:
+            self.inv_link = inv_link
 
         self.num_observations = X.shape[0]
         self.num_variates = X.shape[1]
         self.available_predictors = list(range(self.num_variates))
         self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
-        self.m = m
-        self.alpha = alpha
         self.trees = self.init_list_of_trees()
         self.all_trees = []
         self.mean = fast_mean()
```
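The `scale` argument above sets the standard deviation of the Gaussian noise added to each proposed leaf value. A standalone sketch of the arithmetic (all numbers are made up):

```python
import numpy as np

# Illustrative only: the leaf-noise sigma shrinks as the number of trees m grows.
Y = np.random.normal(10, 2, size=500)
m, scale = 200, 0.5

sigma = Y.std() / m ** scale   # what NormalSampler receives when scale is a number
# scale=None instead yields NormalSampler(sigma=None), whose draws are exactly 0
print(round(sigma, 3))         # roughly 2 / sqrt(200), i.e. about 0.14
```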
```diff
@@ -66,7 +90,9 @@ def preprocess_XY(self, X, Y):
         if isinstance(X, (Series, DataFrame)):
             X = X.to_numpy()
         missing_data = np.any(np.isnan(X))
-        X = np.random.normal(X, np.std(X, 0) / 100)
+        if self.jitter:
+            X = np.random.normal(X, np.nanstd(X, 0) / 100000)
+        Y = Y.astype(float)
         return X, Y, missing_data
 
     def init_list_of_trees(self):
```
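Jittering is now opt-in and far smaller than before (`/ 100000` instead of `/ 100`, with `nanstd` so missing data does not poison the scale). A quick sketch of what it does to tied values:

```python
import numpy as np

# Tied X values get hair-thin Gaussian noise, creating distinct candidate splits.
X = np.array([[1.0, 2.0],
              [1.0, 2.0],   # duplicate row: no usable split point without jitter
              [1.0, 3.0]])
X_jittered = np.random.normal(X, np.nanstd(X, axis=0) / 100000)
print(np.unique(X_jittered[:, 1]).size)  # 3 distinct values instead of 2
```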
```diff
@@ -155,32 +181,27 @@ def get_new_idx_data_points(self, current_split_node, idx_data_points):
 
     def get_residuals(self):
         """Compute the residuals."""
-        R_j = self.Y - self.sum_trees_output
-        return R_j
+        R_j = self.Y - self.inv_link(self.sum_trees_output)
 
-    def get_residuals_loo(self, tree):
-        """Compute the residuals without leaving the passed tree out."""
-        R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
         return R_j
 
     def draw_leaf_value(self, idx_data_points):
-        """ Draw the residual mean."""
+        """Draw the residual mean."""
         R_j = self.get_residuals()[idx_data_points]
-        draw = self.mean(R_j)
+        draw = self.mean(R_j) + self.leaf_scale.random()
         return draw
 
     def predict(self, X_new):
         """Compute out of sample predictions evaluated at X_new"""
         trees = self.all_trees
         num_observations = X_new.shape[0]
         pred = np.zeros((len(trees), num_observations))
-        np.random.randint(len(trees))
         for draw, trees_to_sum in enumerate(trees):
             new_Y = np.zeros(num_observations)
             for tree in trees_to_sum:
                 new_Y += [tree.predict_out_of_sample(x) for x in X_new]
             pred[draw] = new_Y
-        return pred
+        return self.inv_link(pred)
 
 
 def compute_prior_probability(alpha):
```
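With a non-identity link, residuals are now formed on the response scale: the latent sum of trees is pushed through `inv_link` before being compared with `Y`, and `predict` applies the same transform to its output. A small sketch, using `expit` as an assumed inverse link and made-up latent values:

```python
import numpy as np
from scipy.special import expit

Y = np.array([0.0, 1.0, 1.0, 0.0])                   # binary response
sum_trees_output = np.array([-1.2, 0.8, 2.0, -0.5])  # latent scale, made-up values

R_j = Y - expit(sum_trees_output)  # residuals live on the (0, 1) response scale
```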
```diff
@@ -257,6 +278,24 @@ def rvs(self):
             return i
 
 
+class NormalSampler:
+    def __init__(self, sigma):
+        self.size = 5000
+        self.cache = []
+        self.sigma = sigma
+
+    def random(self):
+        if self.sigma is None:
+            return 0
+        else:
+            if not self.cache:
+                self.update()
+            return self.cache.pop()
+
+    def update(self):
+        self.cache = np.random.normal(loc=0.0, scale=self.sigma, size=self.size).tolist()
+
+
 class BART(BaseBART):
     """
     BART distribution.
```
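The new `NormalSampler` amortizes the cost of leaf-noise draws by sampling 5000 values at once and popping them off a cached list; `sigma=None` degenerates to a silent sampler that always returns 0. A usage sketch, assuming the class above is in scope (the sigma value is illustrative):

```python
import numpy as np

sampler = NormalSampler(sigma=0.1)
draws = [sampler.random() for _ in range(3)]  # first call fills the 5000-item cache

silent = NormalSampler(sigma=None)
assert silent.random() == 0                   # scale=None case: leaf noise disabled
```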
```diff
@@ -278,10 +317,23 @@ class BART(BaseBART):
         Each element of split_prior should be in the [0, 1] interval and the elements should sum
         to 1. Otherwise they will be normalized.
         Defaults to None, all variables have the same a priori probability.
+    scale : float
+        Controls the variance of the proposed leaf value. The leaf values are computed as a
+        Gaussian with mean equal to the conditional residual mean and variance proportional to
+        the variance of the response variable, and inversely proportional to the number of trees
+        and the scale parameter. Defaults to None, i.e. the variance is 0.
+    inv_link : numpy function
+        Inverse link function. Defaults to None, i.e. the identity function.
+    jitter : bool
+        Whether to jitter the X values or not. Defaults to False. When values of X are repeated,
+        jittering X has the effect of increasing the number of effective splitting variables;
+        otherwise it does not have any effect.
     """
 
-    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
-        super().__init__(X, Y, m, alpha, split_prior)
+    def __init__(
+        self, X, Y, m=200, alpha=0.25, split_prior=None, scale=None, inv_link=None, jitter=False
+    ):
+        super().__init__(X, Y, m, alpha, split_prior, scale, inv_link)
 
     def _str_repr(self, name=None, dist=None, formatting="plain"):
         if dist is None:
```
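For count data the same machinery works with an exponential inverse link. Another hedged sketch (synthetic data; `np.exp` and the parameter values are assumptions, not mandated by the commit):

```python
import numpy as np
import pymc3 as pm

rng = np.random.default_rng(1)
X = rng.uniform(0, 1, size=(200, 3))
Y = rng.poisson(np.exp(1 + X[:, 0]))  # synthetic counts

with pm.Model():
    # scale=0.5 adds Gaussian noise with sd Y.std() / sqrt(m) to each proposed leaf
    mu = pm.BART("mu", X, Y, m=50, scale=0.5, inv_link=np.exp)
    y_obs = pm.Poisson("y_obs", mu=mu, observed=Y)
```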

pymc3/distributions/tree.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -45,12 +45,13 @@ class Tree:
     tree_id : int, optional
     """
 
-    def __init__(self, tree_id=0):
+    def __init__(self, tree_id=0, num_observations=0):
         self.tree_structure = {}
         self.num_nodes = 0
         self.idx_leaf_nodes = []
         self.idx_prunable_split_nodes = []
         self.tree_id = tree_id
+        self.num_observations = num_observations
 
     def __getitem__(self, index):
         return self.get_node(index)
```
```diff
@@ -77,11 +78,12 @@ def delete_node(self, index):
         del self.tree_structure[index]
         self.num_nodes -= 1
 
-    def predict_output(self, num_observations):
-        output = np.zeros(num_observations)
+    def predict_output(self):
+        output = np.zeros(self.num_observations)
         for node_index in self.idx_leaf_nodes:
             current_node = self.get_node(node_index)
             output[current_node.idx_data_points] = current_node.value
+
         return output
 
     def predict_out_of_sample(self, x):
```
```diff
@@ -163,7 +165,7 @@ def init_tree(tree_id, leaf_node_value, idx_data_points):
         -------
 
         """
-        new_tree = Tree(tree_id)
+        new_tree = Tree(tree_id, len(idx_data_points))
         new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
         return new_tree
 
```
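Since a `Tree` now remembers its own number of observations, `predict_output()` takes no argument. A sketch of the revised call pattern, assuming the module path shown above:

```python
import numpy as np
from pymc3.distributions.tree import Tree

idx = np.arange(10, dtype="int32")
tree = Tree.init_tree(tree_id=0, leaf_node_value=0.5, idx_data_points=idx)
print(tree.predict_output().shape)  # (10,): the stored size, no argument needed
```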

pymc3/step_methods/pgbart.py

Lines changed: 17 additions & 21 deletions
```diff
@@ -86,12 +86,13 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
 
     def astep(self, _):
         bart = self.bart
+        inv_link = bart.inv_link
         num_observations = bart.num_observations
         variable_inclusion = np.zeros(bart.num_variates, dtype="int")
 
         # For the tuning phase we restrict max_stages to a low number, otherwise it is almost sure
         # we will reach max_stages given that our first set of m trees is not good at all.
-        # Can set max_stages as a function of the number of variables/dimensions?
+        # Can set max_stages as a function of the number of variables/dimensions? XXX
         if self.tune:
             max_stages = 5
         else:
```
```diff
@@ -105,10 +106,11 @@ def astep(self, _):
                 break
             self.idx += 1
             tree = bart.trees[idx]
-            R_j = bart.get_residuals_loo(tree)
+            old_prediction = tree.predict_output()
+            bart.sum_trees_output -= old_prediction
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
-            particles = self.init_particles(tree.tree_id, R_j, num_observations)
+            particles = self.init_particles(tree.tree_id, num_observations, inv_link)
 
             for t in range(1, max_stages):
                 # Get old particle at stage t
```
```diff
@@ -119,13 +121,12 @@ def astep(self, _):
                 # Update weights. Since the prior is used as the proposal, the weights
                 # are updated additively as the ratio of the new and old log_likelihoods
                 for p_idx, p in enumerate(particles):
-                    new_likelihood = self.likelihood_logp(p.tree.predict_output(num_observations))
+                    new_likelihood = self.likelihood_logp(inv_link(p.tree.predict_output()))
                     p.log_weight += new_likelihood - p.old_likelihood_logp
                     p.old_likelihood_logp = new_likelihood
 
                 # Normalize weights
                 W, normalized_weights = self.normalize(particles)
-
                 # Resample all but first particle
                 re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
                 new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
```
```diff
@@ -148,8 +149,8 @@ def astep(self, _):
             new_tree = np.random.choice(particles, p=normalized_weights)
             self.old_trees_particles_list[tree.tree_id] = new_tree
             bart.trees[idx] = new_tree.tree
-            new_prediction = new_tree.tree.predict_output(num_observations)
-            bart.sum_trees_output = bart.Y - R_j + new_prediction
+            new_prediction = new_tree.tree.predict_output()
+            bart.sum_trees_output += new_prediction
 
             if not self.tune:
                 self.iter += 1
```
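The step method now maintains `sum_trees_output` incrementally: the prediction of the tree being replaced is subtracted before the SMC pass, and the winner's prediction is added back afterwards, replacing the old `bart.Y - R_j + new_prediction` bookkeeping. Schematically (the arrays are stand-ins for real tree predictions):

```python
import numpy as np

sum_trees_output = np.array([1.0, 2.0, 3.0])
old_prediction = np.array([0.5, 0.5, 0.5])  # tree.predict_output() of the old tree
new_prediction = np.array([0.4, 0.6, 0.7])  # prediction of the resampled tree

sum_trees_output -= old_prediction  # remove the tree being replaced
# ... SMC stages select the replacement tree here ...
sum_trees_output += new_prediction  # fold the winner back into the running sum
```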
```diff
@@ -161,8 +162,7 @@ def astep(self, _):
                 variable_inclusion[index] += 1
 
         stats = {"variable_inclusion": variable_inclusion}
-
-        return bart.sum_trees_output, [stats]
+        return inv_link(bart.sum_trees_output), [stats]
 
     @staticmethod
     def competence(var, has_grad):
```
```diff
@@ -194,31 +194,26 @@ def get_old_tree_particle(self, tree_id, t):
         old_tree_particle.set_particle_to_step(t)
         return old_tree_particle
 
-    def init_particles(self, tree_id, R_j, num_observations):
+    def init_particles(self, tree_id, num_observations, inv_link):
         """
         Initialize particles
         """
         # The first particle is from the tree we are trying to replace
         prev_tree = self.get_old_tree_particle(tree_id, 0)
-        likelihood = self.likelihood_logp(prev_tree.tree.predict_output(num_observations))
+        likelihood = self.likelihood_logp(inv_link(prev_tree.tree.predict_output()))
         prev_tree.old_likelihood_logp = likelihood
         prev_tree.log_weight = likelihood - self.log_num_particles
         particles = [prev_tree]
 
         # The rest of the particles are identically initialized
-        initial_value_leaf_nodes = R_j.mean()
         initial_idx_data_points_leaf_nodes = np.arange(num_observations, dtype="int32")
         new_tree = Tree.init_tree(
             tree_id=tree_id,
-            leaf_node_value=initial_value_leaf_nodes,
+            leaf_node_value=0,
             idx_data_points=initial_idx_data_points_leaf_nodes,
         )
-        likelihood_logp = self.likelihood_logp(new_tree.predict_output(num_observations))
-        log_weight = likelihood_logp - self.log_num_particles
         for i in range(1, self.num_particles):
-            particles.append(
-                ParticleTree(new_tree, self.bart.prior_prob_leaf_node, log_weight, likelihood_logp)
-            )
+            particles.append(ParticleTree(new_tree, self.bart.prior_prob_leaf_node, 0, 0))
 
         return np.array(particles)
 
```
```diff
@@ -237,10 +232,10 @@ class ParticleTree:
 
     def __init__(self, tree, prior_prob_leaf_node, log_weight=0, likelihood=0):
         self.tree = tree.copy()  # keeps the tree we care about at the moment
-        self.expansion_nodes = tree.idx_leaf_nodes.copy()  # This should be the array [0]
+        self.expansion_nodes = [0]
         self.tree_history = [self.tree]
         self.expansion_nodes_history = [self.expansion_nodes]
-        self.log_weight = 0
+        self.log_weight = log_weight
         self.prior_prob_leaf_node = prior_prob_leaf_node
         self.old_likelihood_logp = likelihood
         self.used_variates = []
```
```diff
@@ -253,7 +248,8 @@ def sample_tree_sequential(self, bart):
 
         if prob_leaf < np.random.random():
             grow_successful, index_selected_predictor = bart.grow_tree(
-                self.tree, index_leaf_node
+                self.tree,
+                index_leaf_node,
             )
             if grow_successful:
                 # Add new leaf nodes indexes
```
