Skip to content

Commit fdba010

Browse files
authored
Allow X to be a shared variable (#21)
* pps * pass shared variable * clean * fix tests * check Variable * avoid reshaping list of trees * add test * fix test * remove comments
1 parent ce302bc commit fdba010

File tree

5 files changed

+125
-38
lines changed

5 files changed

+125
-38
lines changed

pymc_bart/bart.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,20 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
from multiprocessing import Manager
1718
import aesara.tensor as at
1819
import numpy as np
1920

2021
from aeppl.logprob import _logprob
2122
from aesara.tensor.random.op import RandomVariable
23+
from aesara.tensor.var import Variable
24+
2225
from pandas import DataFrame, Series
2326

2427
from pymc.distributions.distribution import Distribution, _moment
2528

29+
from .utils import sample_posterior
30+
2631
__all__ = ["BART"]
2732

2833

@@ -34,16 +39,24 @@ class BARTRV(RandomVariable):
3439
ndims_params = [2, 1, 0, 0, 1]
3540
dtype = "floatX"
3641
_print_name = ("BART", "\\operatorname{BART}")
42+
all_trees = None
3743

3844
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
39-
return (self.X.shape[0],)
45+
if isinstance(self.X, Variable):
46+
shape = self.X.shape[0].eval()
47+
else:
48+
shape = self.X.shape[0]
49+
return (shape,)
4050

4151
@classmethod
42-
def rng_fn(cls, rng, X, Y, m, alpha, split_prior, size):
43-
if size is not None:
44-
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
52+
def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None, size=None):
53+
if not cls.all_trees:
54+
if size is not None:
55+
return np.full((size[0], cls.Y.shape[0]), cls.Y.mean())
56+
else:
57+
return np.full(cls.Y.shape[0], cls.Y.mean())
4558
else:
46-
return np.full(cls.Y.shape[0], cls.Y.mean())
59+
return sample_posterior(cls.all_trees, cls.X)
4760

4861

4962
bart = BARTRV()
@@ -69,7 +82,7 @@ class BART(Distribution):
6982
split_prior : array-like
7083
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
7184
1. Otherwise they will be normalized.
72-
Defaults to None, i.e. all covariates have the same prior probability to be selected.
85+
Defaults to 0, i.e. all covariates have the same prior probability to be selected.
7386
"""
7487

7588
def __new__(
@@ -82,17 +95,20 @@ def __new__(
8295
split_prior=None,
8396
**kwargs,
8497
):
98+
manager = Manager()
99+
cls.all_trees = manager.list()
85100

86101
X, Y = preprocess_xy(X, Y)
87102

88103
if split_prior is None:
89-
split_prior = np.ones(X.shape[1])
104+
split_prior = []
90105

91106
bart_op = type(
92107
f"BART_{name}",
93108
(BARTRV,),
94109
dict(
95110
name="BART",
111+
all_trees=cls.all_trees,
96112
inplace=False,
97113
initval=Y.mean(),
98114
X=X,
@@ -142,8 +158,10 @@ def preprocess_xy(X, Y):
142158
Y = Y.to_numpy()
143159
if isinstance(X, (Series, DataFrame)):
144160
X = X.to_numpy()
161+
145162
Y = Y.astype(float)
146163
X = X.astype(float)
164+
147165
return X, Y
148166

149167

pymc_bart/pgbart.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717
from copy import deepcopy
1818
from numba import njit
1919

20-
import aesara
2120
import numpy as np
2221

2322
from aesara import function as aesara_function
23+
from aesara import config
24+
from aesara.tensor.var import Variable
25+
2426
from pymc.model import modelcontext
2527
from pymc.step_methods.arraystep import ArrayStepShared, Competence
2628
from pymc.aesaraf import inputvars, join_nonshared_inputs, make_shared_replacements
2729

30+
2831
from pymc_bart.bart import BARTRV
2932
from pymc_bart.tree import LeafNode, SplitNode, Tree
3033

@@ -53,7 +56,7 @@ class PGBART(ArrayStepShared):
5356
name = "pgbart"
5457
default_blocked = False
5558
generates_stats = True
56-
stats_dtypes = [{"variable_inclusion": object, "bart_trees": object}]
59+
stats_dtypes = [{"variable_inclusion": object}]
5760

5861
def __init__(
5962
self,
@@ -72,7 +75,11 @@ def __init__(
7275
value_bart = vars[0]
7376
self.bart = model.values_to_rvs[value_bart].owner.op
7477

75-
self.X = self.bart.X
78+
if isinstance(self.bart.X, Variable):
79+
self.X = self.bart.X.eval()
80+
else:
81+
self.X = self.bart.X
82+
7683
self.Y = self.bart.Y
7784
self.missing_data = np.any(np.isnan(self.X))
7885
self.m = self.bart.m
@@ -83,7 +90,11 @@ def __init__(
8390
else:
8491
self.shape = shape[0]
8592

86-
self.alpha_vec = self.bart.split_prior
93+
# self.alpha_vec = self.bart.split_prior
94+
if self.bart.split_prior:
95+
self.alpha_vec = self.bart.split_prior
96+
else:
97+
self.alpha_vec = np.ones(self.X.shape[1])
8798
self.init_mean = self.Y.mean()
8899
# if data is binary
89100
y_unique = np.unique(self.Y)
@@ -98,7 +109,7 @@ def __init__(
98109
self.available_predictors = list(range(self.num_variates))
99110

100111
self.sum_trees = np.full((self.shape, self.Y.shape[0]), self.init_mean).astype(
101-
aesara.config.floatX
112+
config.floatX
102113
)
103114
self.sum_trees_noi = self.sum_trees - (self.init_mean / self.m)
104115
self.a_tree = Tree(
@@ -200,7 +211,10 @@ def astep(self, _):
200211
for index in used_variates:
201212
variable_inclusion[index] += 1
202213

203-
stats = {"variable_inclusion": variable_inclusion, "bart_trees": self.all_trees}
214+
if not self.tune:
215+
self.bart.all_trees.append(self.all_trees)
216+
217+
stats = {"variable_inclusion": variable_inclusion}
204218
return self.sum_trees, [stats]
205219

206220
def normalize(self, particles):
@@ -261,7 +275,7 @@ def systematic(self, normalized_weights):
261275
Note: adapted from https://github.com/nchopin/particles
262276
"""
263277
lnw = len(normalized_weights)
264-
single_uniform = (self.uniform.random() + np.arange(lnw)) / lnw
278+
single_uniform = (self.uniform.random()[0] + np.arange(lnw)) / lnw
265279
return inverse_cdf(single_uniform, normalized_weights) + 2
266280

267281
def init_particles(self, tree_id: int) -> np.ndarray:

pymc_bart/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from copy import deepcopy
1818

19-
import aesara
19+
from aesara import config
2020
import numpy as np
2121

2222

@@ -59,7 +59,7 @@ def __init__(self, leaf_node_value, idx_data_points, num_observations, shape):
5959
0: LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
6060
}
6161
self.idx_leaf_nodes = [0]
62-
self.output = np.zeros((num_observations, shape)).astype(aesara.config.floatX).squeeze()
62+
self.output = np.zeros((num_observations, shape)).astype(config.floatX).squeeze()
6363

6464
def __getitem__(self, index):
6565
return self.get_node(index)

pymc_bart/utils.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,21 @@
44
import matplotlib.pyplot as plt
55
import numpy as np
66

7+
from aesara.tensor.var import Variable
78
from numpy.random import RandomState
89
from scipy.interpolate import griddata
910
from scipy.signal import savgol_filter
1011
from scipy.stats import pearsonr
1112

1213

13-
def predict(idata, rng, X, size=None, excluded=None):
14+
def predict(bartrv, rng, X, size=None, excluded=None):
1415
"""
1516
Generate samples from the BART-posterior.
1617
1718
Parameters
1819
----------
19-
idata : InferenceData
20-
InferenceData containing a collection of BART_trees in sample_stats group
20+
bartrv : BART Random Variable
21+
BART variable once the model that include it has been fitted.
2122
rng: NumPy random generator
2223
X : array-like
2324
A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
@@ -27,8 +28,10 @@ def predict(idata, rng, X, size=None, excluded=None):
2728
excluded : list
2829
indexes of the variables to exclude when computing predictions
2930
"""
30-
bart_trees = idata.sample_stats.bart_trees
31-
stacked_trees = bart_trees.stack(trees=["chain", "draw"])
31+
stacked_trees = bartrv.owner.op.all_trees
32+
if isinstance(X, Variable):
33+
X = X.eval()
34+
3235
if size is None:
3336
size = ()
3437
elif isinstance(size, int):
@@ -38,20 +41,49 @@ def predict(idata, rng, X, size=None, excluded=None):
3841
for s in size:
3942
flatten_size *= s
4043

41-
idx = rng.randint(len(stacked_trees.trees), size=flatten_size)
42-
shape = stacked_trees.isel(trees=0).values[0].predict(X[0]).size
44+
idx = rng.randint(len(stacked_trees), size=flatten_size)
45+
shape = stacked_trees[0][0].predict(X[0]).size
4346

4447
pred = np.zeros((flatten_size, X.shape[0], shape))
4548

4649
for ind, p in enumerate(pred):
47-
for tree in stacked_trees.isel(trees=idx[ind]).values:
50+
for tree in stacked_trees[idx[ind]]:
4851
p += np.array([tree.predict(x, excluded) for x in X])
4952
pred.reshape((*size, shape, -1))
5053
return pred
5154

5255

56+
def sample_posterior(all_trees, X):
57+
"""
58+
Generate samples from the BART-posterior.
59+
60+
Parameters
61+
----------
62+
all_trees : list
63+
List of all trees sampled from a posterior
64+
X : array-like
65+
A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
66+
out-of-sample predictions.
67+
m : int
68+
Number of trees
69+
"""
70+
stacked_trees = all_trees
71+
idx = np.random.randint(len(stacked_trees))
72+
if isinstance(X, Variable):
73+
X = X.eval()
74+
75+
shape = stacked_trees[0][0].predict(X[0]).size
76+
77+
pred = np.zeros((1, X.shape[0], shape))
78+
79+
for p in pred:
80+
for tree in stacked_trees[idx]:
81+
p += np.array([tree.predict(x) for x in X])
82+
return pred.squeeze()
83+
84+
5385
def plot_dependence(
54-
idata,
86+
bartrv,
5587
X,
5688
Y=None,
5789
kind="pdp",
@@ -79,8 +111,8 @@ def plot_dependence(
79111
80112
Parameters
81113
----------
82-
idata: InferenceData
83-
InferenceData containing a collection of BART_trees in sample_stats group
114+
bartrv : BART Random Variable
115+
BART variable once the model that include it has been fitted.
84116
X : array-like
85117
The covariate matrix.
86118
Y : array-like
@@ -149,6 +181,9 @@ def plot_dependence(
149181

150182
rng = RandomState(seed=random_seed)
151183

184+
if isinstance(X, Variable):
185+
X = X.eval()
186+
152187
if hasattr(X, "columns") and hasattr(X, "values"):
153188
x_names = list(X.columns)
154189
X = X.values
@@ -207,13 +242,13 @@ def plot_dependence(
207242
for x_i in new_x_i:
208243
new_X[:, indices_mi] = X[:, indices_mi]
209244
new_X[:, i] = x_i
210-
y_pred.append(np.mean(predict(idata, rng, X=new_X, size=samples), 1))
245+
y_pred.append(np.mean(predict(bartrv, rng, X=new_X, size=samples), 1))
211246
new_x_target.append(new_x_i)
212247
else:
213248
for instance in instances:
214249
new_X = X[idx_s]
215250
new_X[:, indices_mi] = X[:, indices_mi][instance]
216-
y_pred.append(np.mean(predict(idata, rng, X=new_X, size=samples), 0))
251+
y_pred.append(np.mean(predict(bartrv, rng, X=new_X, size=samples), 0))
217252
new_x_target.append(new_X[:, i])
218253
y_mins.append(np.min(y_pred))
219254
new_y.append(np.array(y_pred).T)
@@ -310,7 +345,7 @@ def plot_dependence(
310345

311346

312347
def plot_variable_importance(
313-
idata, X, labels=None, sort_vars=True, figsize=None, samples=100, random_seed=None
348+
idata, bartrv, X, labels=None, sort_vars=True, figsize=None, samples=100, random_seed=None
314349
):
315350
"""
316351
Estimates variable importance from the BART-posterior.
@@ -319,6 +354,8 @@ def plot_variable_importance(
319354
----------
320355
idata: InferenceData
321356
InferenceData containing a collection of BART_trees in sample_stats group
357+
bartrv : BART Random Variable
358+
BART variable once the model that include it has been fitted.
322359
X : array-like
323360
The covariate matrix.
324361
labels : list
@@ -365,12 +402,12 @@ def plot_variable_importance(
365402
axes[0].set_xlabel("covariables")
366403
axes[0].set_ylabel("importance")
367404

368-
predicted_all = predict(idata, rng, X=X, size=samples, excluded=None)
405+
predicted_all = predict(bartrv, rng, X=X, size=samples, excluded=None)
369406

370407
ev_mean = np.zeros(len(var_imp))
371408
ev_hdi = np.zeros((len(var_imp), 2))
372409
for idx, subset in enumerate(subsets):
373-
predicted_subset = predict(idata, rng, X=X, size=samples, excluded=subset)
410+
predicted_subset = predict(bartrv, rng, X=X, size=samples, excluded=subset)
374411
pearson = np.zeros(samples)
375412
for j in range(samples):
376413
pearson[j] = (

0 commit comments

Comments
 (0)