Skip to content

Commit f32bc75

Browse files
authored
Clean utils (#23)
* clean utils * remove comments * remove unused import
1 parent fdba010 commit f32bc75

File tree

4 files changed

+33
-54
lines changed

4 files changed

+33
-54
lines changed

pymc_bart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from pymc_bart.bart import BART
1717
from pymc_bart.pgbart import PGBART
18-
from pymc_bart.utils import plot_dependence, plot_variable_importance, predict
18+
from pymc_bart.utils import plot_dependence, plot_variable_importance
1919

2020
__all__ = ["BART", "PGBART"]
2121
__version__ = "0.1.0"

pymc_bart/bart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from pymc.distributions.distribution import Distribution, _moment
2828

29-
from .utils import sample_posterior
29+
from .utils import _sample_posterior
3030

3131
__all__ = ["BART"]
3232

@@ -56,7 +56,7 @@ def rng_fn(cls, rng=None, X=None, Y=None, m=None, alpha=None, split_prior=None,
5656
else:
5757
return np.full(cls.Y.shape[0], cls.Y.mean())
5858
else:
59-
return sample_posterior(cls.all_trees, cls.X)
59+
return _sample_posterior(cls.all_trees, cls.X, rng=rng).squeeze()
6060

6161

6262
bart = BARTRV()

pymc_bart/utils.py

Lines changed: 24 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,29 @@
55
import numpy as np
66

77
from aesara.tensor.var import Variable
8-
from numpy.random import RandomState
98
from scipy.interpolate import griddata
109
from scipy.signal import savgol_filter
1110
from scipy.stats import pearsonr
1211

1312

14-
def predict(bartrv, rng, X, size=None, excluded=None):
13+
def _sample_posterior(all_trees, X, rng, size=None, excluded=None):
1514
"""
1615
Generate samples from the BART-posterior.
1716
1817
Parameters
1918
----------
20-
bartrv : BART Random Variable
21-
BART variable once the model that include it has been fitted.
22-
rng: NumPy random generator
19+
all_trees : list
20+
List of all trees sampled from a posterior
2321
X : array-like
2422
A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
2523
out-of-sample predictions.
24+
rng : NumPy RandomGenerator
2625
size : int or tuple
2726
Number of samples.
2827
excluded : list
29-
indexes of the variables to exclude when computing predictions
28+
Indexes of the variables to exclude when computing predictions
3029
"""
31-
stacked_trees = bartrv.owner.op.all_trees
30+
stacked_trees = all_trees
3231
if isinstance(X, Variable):
3332
X = X.eval()
3433

@@ -41,7 +40,7 @@ def predict(bartrv, rng, X, size=None, excluded=None):
4140
for s in size:
4241
flatten_size *= s
4342

44-
idx = rng.randint(len(stacked_trees), size=flatten_size)
43+
idx = rng.integers(0, len(stacked_trees), size=flatten_size)
4544
shape = stacked_trees[0][0].predict(X[0]).size
4645

4746
pred = np.zeros((flatten_size, X.shape[0], shape))
@@ -53,35 +52,6 @@ def predict(bartrv, rng, X, size=None, excluded=None):
5352
return pred
5453

5554

56-
def sample_posterior(all_trees, X):
57-
"""
58-
Generate samples from the BART-posterior.
59-
60-
Parameters
61-
----------
62-
all_trees : list
63-
List of all trees sampled from a posterior
64-
X : array-like
65-
A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
66-
out-of-sample predictions.
67-
m : int
68-
Number of trees
69-
"""
70-
stacked_trees = all_trees
71-
idx = np.random.randint(len(stacked_trees))
72-
if isinstance(X, Variable):
73-
X = X.eval()
74-
75-
shape = stacked_trees[0][0].predict(X[0]).size
76-
77-
pred = np.zeros((1, X.shape[0], shape))
78-
79-
for p in pred:
80-
for tree in stacked_trees[idx]:
81-
p += np.array([tree.predict(x) for x in X])
82-
return pred.squeeze()
83-
84-
8555
def plot_dependence(
8656
bartrv,
8757
X,
@@ -179,8 +149,6 @@ def plot_dependence(
179149
Available option are 'insample', 'linear' or 'quantiles'"""
180150
)
181151

182-
rng = RandomState(seed=random_seed)
183-
184152
if isinstance(X, Variable):
185153
X = X.eval()
186154

@@ -195,6 +163,8 @@ def plot_dependence(
195163
else:
196164
y_label = "Predicted Y"
197165

166+
rng = np.random.default_rng(random_seed)
167+
198168
num_covariates = X.shape[1]
199169

200170
indices = list(range(num_covariates))
@@ -216,14 +186,15 @@ def plot_dependence(
216186
xs_values = [0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95]
217187

218188
if kind == "ice":
219-
instances = np.random.choice(range(X.shape[0]), replace=False, size=instances)
189+
instances = rng.choice(range(X.shape[0]), replace=False, size=instances)
220190

221191
new_y = []
222192
new_x_target = []
223193
y_mins = []
224194

225195
new_X = np.zeros_like(X)
226196
idx_s = list(range(X.shape[0]))
197+
all_trees = bartrv.owner.op.all_trees
227198
for i in var_idx:
228199
indices_mi = indices[:]
229200
indices_mi.pop(i)
@@ -242,13 +213,17 @@ def plot_dependence(
242213
for x_i in new_x_i:
243214
new_X[:, indices_mi] = X[:, indices_mi]
244215
new_X[:, i] = x_i
245-
y_pred.append(np.mean(predict(bartrv, rng, X=new_X, size=samples), 1))
216+
y_pred.append(
217+
np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 1)
218+
)
246219
new_x_target.append(new_x_i)
247220
else:
248221
for instance in instances:
249222
new_X = X[idx_s]
250223
new_X[:, indices_mi] = X[:, indices_mi][instance]
251-
y_pred.append(np.mean(predict(bartrv, rng, X=new_X, size=samples), 0))
224+
y_pred.append(
225+
np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 0)
226+
)
252227
new_x_target.append(new_X[:, i])
253228
y_mins.append(np.min(y_pred))
254229
new_y.append(np.array(y_pred).T)
@@ -328,7 +303,7 @@ def plot_dependence(
328303
nxi,
329304
nyi,
330305
smooth=smooth,
331-
fill_kwargs={"alpha": alpha},
306+
fill_kwargs={"alpha": alpha, "color": color},
332307
ax=ax,
333308
)
334309
ax.plot(nxi[idx], nyi[idx].mean(0), color=color)
@@ -374,7 +349,6 @@ def plot_variable_importance(
374349
idxs: indexes of the covariates from higher to lower relative importance
375350
axes: matplotlib axes
376351
"""
377-
rng = RandomState(seed=random_seed)
378352
_, axes = plt.subplots(2, 1, figsize=figsize)
379353

380354
if hasattr(X, "columns") and hasattr(X, "values"):
@@ -387,6 +361,8 @@ def plot_variable_importance(
387361
else:
388362
labels = np.array(labels)
389363

364+
rng = np.random.default_rng(random_seed)
365+
390366
ticks = np.arange(len(var_imp), dtype=int)
391367
idxs = np.argsort(var_imp)
392368
subsets = [idxs[:-i] for i in range(1, len(idxs))]
@@ -402,12 +378,14 @@ def plot_variable_importance(
402378
axes[0].set_xlabel("covariables")
403379
axes[0].set_ylabel("importance")
404380

405-
predicted_all = predict(bartrv, rng, X=X, size=samples, excluded=None)
381+
all_trees = bartrv.owner.op.all_trees
382+
383+
predicted_all = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=None)
406384

407385
ev_mean = np.zeros(len(var_imp))
408386
ev_hdi = np.zeros((len(var_imp), 2))
409387
for idx, subset in enumerate(subsets):
410-
predicted_subset = predict(bartrv, rng, X=X, size=samples, excluded=subset)
388+
predicted_subset = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=subset)
411389
pearson = np.zeros(samples)
412390
for j in range(samples):
413391
pearson[j] = (

tests/test_bart.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,12 @@ class TestUtils:
9191
y = pm.Normal("y", mu, sigma, observed=Y)
9292
idata = pm.sample(random_seed=3415)
9393

94-
def test_predict(self):
95-
rng = RandomState(12345)
96-
pred_all = pmb.predict(self.mu, rng, X=self.X, size=2)
97-
rng = RandomState(12345)
98-
pred_first = pmb.predict(self.mu, rng, X=self.X[:10])
94+
def test_sample_posterior(self):
95+
all_trees = self.mu.owner.op.all_trees
96+
rng = np.random.default_rng(3)
97+
pred_all = pmb.utils._sample_posterior(all_trees, X=self.X, rng=rng, size=2)
98+
rng = np.random.default_rng(3)
99+
pred_first = pmb.utils._sample_posterior(all_trees, X=self.X[:10], rng=rng)
99100

100101
assert_almost_equal(pred_first[0], pred_all[0, :10], decimal=4)
101102
assert pred_all.shape == (2, 50, 1)

0 commit comments

Comments
 (0)