Skip to content

Commit c88edea

Browse files
authored
fig bug and func argument (#61)
1 parent ec8b84d commit c88edea

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

pymc_bart/pgbart.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ def __init__(
113113
shape=self.shape,
114114
)
115115
self.normal = NormalSampler(mu_std, self.shape)
116-
self.uniform = UniformSampler(0.33, 0.75, self.shape)
116+
self.uniform = UniformSampler(0, 1)
117+
self.uniform_kf = UniformSampler(0.33, 0.75, self.shape)
117118
self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha)
118119
self.ssv = SampleSplittingVariable(self.alpha_vec)
119120

@@ -253,7 +254,6 @@ def get_particle_tree(self, particles, normalized_weights):
253254
discrete_uniform_sampler(self.num_particles)
254255
]
255256
new_particle = particles[new_index - 2]
256-
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
257257
return new_particle, new_particle.tree
258258

259259
def systematic(self, normalized_weights):
@@ -265,7 +265,7 @@ def systematic(self, normalized_weights):
265265
Note: adapted from https://github.com/nchopin/particles
266266
"""
267267
lnw = len(normalized_weights)
268-
single_uniform = (self.uniform.random()[0] + np.arange(lnw)) / lnw
268+
single_uniform = (self.uniform.random() + np.arange(lnw)) / lnw
269269
return inverse_cdf(single_uniform, normalized_weights) + 2
270270

271271
def init_particles(self, tree_id: int) -> np.ndarray:
@@ -285,7 +285,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
285285

286286
for _ in self.indices:
287287
particles.append(
288-
ParticleTree(self.a_tree, self.uniform.random() if self.tune else p0.kfactor)
288+
ParticleTree(self.a_tree, self.uniform_kf.random() if self.tune else p0.kfactor)
289289
)
290290

291291
return np.array(particles)
@@ -575,7 +575,7 @@ def update(self):
575575
class UniformSampler:
576576
"""Cache samples from a uniform distribution."""
577577

578-
def __init__(self, lower_bound, upper_bound, shape):
578+
def __init__(self, lower_bound, upper_bound, shape=None):
579579
self.size = 1000
580580
self.upper_bound = upper_bound
581581
self.lower_bound = lower_bound
@@ -585,15 +585,21 @@ def __init__(self, lower_bound, upper_bound, shape):
585585
def random(self):
586586
if self.idx == self.size:
587587
self.update()
588-
pop = self.cache[:, self.idx]
588+
if self.shape is None:
589+
pop = self.cache[self.idx]
590+
else:
591+
pop = self.cache[:, self.idx]
589592
self.idx += 1
590593
return pop
591594

592595
def update(self):
593596
self.idx = 0
594-
self.cache = np.random.uniform(
595-
self.lower_bound, self.upper_bound, size=(self.shape, self.size)
596-
)
597+
if self.shape is None:
598+
self.cache = np.random.uniform(self.lower_bound, self.upper_bound, size=self.size)
599+
else:
600+
self.cache = np.random.uniform(
601+
self.lower_bound, self.upper_bound, size=(self.shape, self.size)
602+
)
597603

598604

599605
@njit
@@ -629,7 +635,7 @@ def inverse_cdf(single_uniform, normalized_weights):
629635

630636

631637
def logp(point, out_vars, vars, shared): # pylint: disable=redefined-builtin
632-
"""Compile Aesara function of the model and the input and output variables.
638+
"""Compile PyTensor function of the model and the input and output variables.
633639
634640
Parameters
635641
----------

pymc_bart/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _sample_posterior(all_trees, X, rng, size=None, excluded=None):
4747

4848
for ind, p in enumerate(pred):
4949
for tree in stacked_trees[idx[ind]]:
50-
p += np.array([tree.predict(x, excluded) for x in X])
50+
p += np.vstack([tree.predict(x, excluded) for x in X])
5151
pred.reshape((*size, shape, -1))
5252
return pred
5353

@@ -61,6 +61,7 @@ def plot_dependence(
6161
xs_values=None,
6262
var_idx=None,
6363
var_discrete=None,
64+
func=None,
6465
samples=50,
6566
instances=10,
6667
random_seed=None,
@@ -104,6 +105,8 @@ def plot_dependence(
104105
List of the indices of the covariate for which to compute the pdp or ice.
105106
var_discrete : list
106107
List of the indices of the covariate treated as discrete.
108+
func : function
109+
Arbitrary function to apply to the predictions. Defaults to the identity function.
107110
samples : int
108111
Number of posterior samples used in the predictions. Defaults to 50
109112
instances : int
@@ -228,6 +231,8 @@ def plot_dependence(
228231
y_mins.append(np.min(y_pred))
229232
new_y.append(np.array(y_pred).T)
230233

234+
if func is not None:
235+
new_y = func(new_y)
231236
shape = 1
232237
if new_y[0].ndim == 3:
233238
shape = new_y[0].shape[0]

0 commit comments

Comments
 (0)