Skip to content

Commit df307a5

Browse files
authored
clean and refactor (#18)
1 parent 0f0e361 commit df307a5

File tree

3 files changed

+28
-24
lines changed

3 files changed

+28
-24
lines changed

pymc_bart/bart.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ class BARTRV(RandomVariable):
3434
ndims_params = [2, 1, 0, 0, 1]
3535
dtype = "floatX"
3636
_print_name = ("BART", "\\operatorname{BART}")
37-
all_trees = None
3837

3938
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
4039
return (self.X.shape[0],)

pymc_bart/pgbart.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(
123123
else:
124124
self.batch = (batch, batch)
125125

126+
self.num_particles = num_particles
126127
self.log_num_particles = np.log(num_particles)
127128
self.indices = list(range(2, num_particles))
128129
self.len_indices = len(self.indices)
@@ -185,14 +186,10 @@ def astep(self, _):
185186

186187
_, normalized_weights = self.normalize(particles)
187188
# Get the new tree and update
188-
new_particle = np.random.choice(particles, p=normalized_weights)
189-
new_tree = new_particle.tree
190-
191-
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
189+
new_particle, new_tree = self.get_particle_tree(particles, normalized_weights)
192190
self.all_particles[tree_id] = new_particle
193191
self.sum_trees = self.sum_trees_noi + new_tree._predict()
194192
self.all_trees[tree_id] = new_tree.trim()
195-
196193
used_variates = new_tree.get_split_variables()
197194

198195
if self.tune:
@@ -230,7 +227,7 @@ def resample(self, particles, normalized_weights):
230227
231228
Ensure particles are copied only if needed.
232229
"""
233-
new_indices = systematic(normalized_weights)
230+
new_indices = self.systematic(normalized_weights)
234231
seen = []
235232
new_particles = []
236233
for idx in new_indices:
@@ -244,6 +241,29 @@ def resample(self, particles, normalized_weights):
244241

245242
return particles
246243

244+
def get_particle_tree(self, particles, normalized_weights):
245+
"""
246+
Sample a new particle, new tree and update log_weight
247+
"""
248+
new_index = self.systematic(normalized_weights)[
249+
discrete_uniform_sampler(self.num_particles)
250+
]
251+
new_particle = particles[new_index - 2]
252+
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
253+
return new_particle, new_particle.tree
254+
255+
def systematic(self, normalized_weights):
256+
"""
257+
Systematic resampling.
258+
259+
Return indices in the range 2, ..., len(normalized_weights)+2
260+
261+
Note: adapted from https://github.com/nchopin/particles
262+
"""
263+
lnw = len(normalized_weights)
264+
single_uniform = (self.uniform.random() + np.arange(lnw)) / lnw
265+
return inverse_cdf(single_uniform, normalized_weights) + 2
266+
247267
def init_particles(self, tree_id: int) -> np.ndarray:
248268
"""Initialize particles."""
249269
p0 = self.all_particles[tree_id]
@@ -584,19 +604,6 @@ def update(self):
584604
)
585605

586606

587-
def systematic(normalized_weights):
588-
"""
589-
Systematic resampling.
590-
591-
Return indices in the range 2, ..., len(normalized_weights)+2
592-
593-
Note: adapted from https://github.com/nchopin/particles
594-
"""
595-
lnw = len(normalized_weights)
596-
single_uniform = (np.random.rand(1) + np.arange(lnw)) / lnw
597-
return inverse_cdf(single_uniform, normalized_weights) + 2
598-
599-
600607
@njit
601608
def inverse_cdf(single_uniform, normalized_weights):
602609
"""

pymc_bart/tree.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ class Tree:
5252
"idx_leaf_nodes",
5353
"output",
5454
"leaf_node_value",
55-
"idx_data_points",
56-
"shape",
5755
)
5856

5957
def __init__(self, leaf_node_value, idx_data_points, num_observations, shape):
@@ -200,7 +198,7 @@ def get_idx_right_child(self):
200198

201199

202200
class SplitNode(BaseNode):
203-
__slots__ = "index", "idx_split_variable", "split_value"
201+
__slots__ = "idx_split_variable", "split_value"
204202

205203
def __init__(self, index, idx_split_variable, split_value):
206204
super().__init__(index)
@@ -210,7 +208,7 @@ def __init__(self, index, idx_split_variable, split_value):
210208

211209

212210
class LeafNode(BaseNode):
213-
__slots__ = "index", "value", "idx_data_points"
211+
__slots__ = "value", "idx_data_points"
214212

215213
def __init__(self, index, value, idx_data_points):
216214
super().__init__(index)

0 commit comments

Comments
 (0)