Skip to content

Commit ff22efe

Browse files
authored
avoid creating variables when it is possible; (#54)
Change get_split_variables() to be a generator. Instead of creating a Particle and then setting the value of its kfactor, refactor the constructor to take kfactor as a parameter with a default value of 0.75. grow_tree now returns the new leaf indexes. In grow_tree, replace the if/else statement with an early return; also, instead of creating a new split node and deleting the leaf node, there is a new method grow_leaf_node. In draw_leaf_value, simplify the expressions. In _traverse_tree, instead of having two returns, create a variable next_node. Add @njit to the functions get_new_idx_data_points and draw_leaf_value.
1 parent b786dae commit ff22efe

File tree

3 files changed

+69
-93
lines changed

3 files changed

+69
-93
lines changed

pymc_bart/pgbart.py

Lines changed: 57 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -187,18 +187,18 @@ def astep(self, _):
187187

188188
_, normalized_weights = self.normalize(particles)
189189
# Get the new tree and update
190-
new_particle, new_tree = self.get_particle_tree(particles, normalized_weights)
191-
self.all_particles[tree_id] = new_particle
190+
self.all_particles[tree_id], new_tree = self.get_particle_tree(
191+
particles, normalized_weights
192+
)
192193
self.sum_trees = self.sum_trees_noi + new_tree._predict()
193194
self.all_trees[tree_id] = new_tree.trim()
194-
used_variates = new_tree.get_split_variables()
195195

196196
if self.tune:
197197
self.ssv = SampleSplittingVariable(self.alpha_vec)
198-
for index in used_variates:
198+
for index in new_tree.get_split_variables():
199199
self.alpha_vec[index] += 1
200200
else:
201-
for index in used_variates:
201+
for index in new_tree.get_split_variables():
202202
variable_inclusion[index] += 1
203203

204204
if not self.tune:
@@ -284,12 +284,9 @@ def init_particles(self, tree_id: int) -> np.ndarray:
284284
particles = [p0, p1]
285285

286286
for _ in self.indices:
287-
pt = ParticleTree(self.a_tree)
288-
if self.tune:
289-
pt.kfactor = self.uniform.random()
290-
else:
291-
pt.kfactor = p0.kfactor
292-
particles.append(pt)
287+
particles.append(
288+
ParticleTree(self.a_tree, self.uniform.random() if self.tune else p0.kfactor)
289+
)
293290

294291
return np.array(particles)
295292

@@ -305,10 +302,10 @@ def update_weight(self, particle, old=False):
305302
)
306303
if old:
307304
particle.log_weight = new_likelihood
308-
particle.old_likelihood_logp = new_likelihood
309305
else:
310306
particle.log_weight += new_likelihood - particle.old_likelihood_logp
311-
particle.old_likelihood_logp = new_likelihood
307+
308+
particle.old_likelihood_logp = new_likelihood
312309

313310
@staticmethod
314311
def competence(var, has_grad):
@@ -324,21 +321,19 @@ class ParticleTree:
324321

325322
__slots__ = "tree", "expansion_nodes", "log_weight", "old_likelihood_logp", "kfactor"
326323

327-
def __init__(self, tree):
324+
def __init__(self, tree, kfactor=0.75):
328325
self.tree = tree.copy()
329326
self.expansion_nodes = [0]
330327
self.log_weight = 0
331328
self.old_likelihood_logp = 0
332-
self.kfactor = 0.75
329+
self.kfactor = kfactor
333330

334331
def copy(self):
335332
p = ParticleTree(self.tree)
336-
p.expansion_nodes, p.log_weight, p.old_likelihood_logp, p.kfactor = (
337-
self.expansion_nodes.copy(),
338-
self.log_weight,
339-
self.old_likelihood_logp,
340-
self.kfactor,
341-
)
333+
p.expansion_nodes = self.expansion_nodes.copy()
334+
p.log_weight = self.log_weight
335+
p.old_likelihood_logp = self.old_likelihood_logp
336+
p.kfactor = self.kfactor
342337
return p
343338

344339
def sample_tree(
@@ -360,7 +355,7 @@ def sample_tree(
360355
prob_leaf = prior_prob_leaf_node[get_depth(index_leaf_node)]
361356

362357
if prob_leaf < np.random.random():
363-
index_selected_predictor = grow_tree(
358+
idx_new_nodes = grow_tree(
364359
self.tree,
365360
index_leaf_node,
366361
ssv,
@@ -373,9 +368,8 @@ def sample_tree(
373368
self.kfactor,
374369
shape,
375370
)
376-
if index_selected_predictor is not None:
377-
new_indexes = self.tree.idx_leaf_nodes[-2:]
378-
self.expansion_nodes.extend(new_indexes)
371+
if idx_new_nodes is not None:
372+
self.expansion_nodes.extend(idx_new_nodes)
379373
tree_grew = True
380374

381375
return tree_grew
@@ -389,8 +383,7 @@ def sample_leafs(self, sum_trees, m, normal, shape):
389383
node_value = draw_leaf_value(
390384
sum_trees[:, idx_data_points],
391385
m,
392-
normal,
393-
self.kfactor,
386+
normal.random() * self.kfactor,
394387
shape,
395388
)
396389
leaf.value = node_value
@@ -463,55 +456,43 @@ def grow_tree(
463456
split_value = get_split_value(available_splitting_values, idx_data_points, missing_data)
464457

465458
if split_value is None:
466-
index_selected_predictor = None
467-
else:
468-
new_idx_data_points = get_new_idx_data_points(
469-
split_value, idx_data_points, selected_predictor, X
470-
)
471-
current_node_children = (
472-
current_node.get_idx_left_child(),
473-
current_node.get_idx_right_child(),
459+
return None
460+
new_idx_data_points = get_new_idx_data_points(
461+
available_splitting_values, split_value, idx_data_points
462+
)
463+
current_node_children = (
464+
current_node.get_idx_left_child(),
465+
current_node.get_idx_right_child(),
466+
)
467+
468+
new_nodes = []
469+
for idx in range(2):
470+
idx_data_point = new_idx_data_points[idx]
471+
node_value = draw_leaf_value(
472+
sum_trees[:, idx_data_point],
473+
m,
474+
normal.random() * kfactor,
475+
shape,
474476
)
475477

476-
new_nodes = []
477-
for idx in range(2):
478-
idx_data_point = new_idx_data_points[idx]
479-
node_value = draw_leaf_value(
480-
sum_trees[:, idx_data_point],
481-
m,
482-
normal,
483-
kfactor,
484-
shape,
485-
)
486-
487-
new_node = Node.new_leaf_node(
488-
index=current_node_children[idx],
489-
value=node_value,
490-
idx_data_points=idx_data_point,
491-
)
492-
new_nodes.append(new_node)
493-
494-
new_split_node = Node.new_split_node(
495-
index=index_leaf_node,
496-
split_value=split_value,
497-
idx_split_variable=selected_predictor,
478+
new_node = Node.new_leaf_node(
479+
index=current_node_children[idx],
480+
value=node_value,
481+
idx_data_points=idx_data_point,
498482
)
483+
new_nodes.append(new_node)
499484

500-
# update tree nodes and indexes
501-
tree.delete_leaf_node(index_leaf_node)
502-
tree.set_node(index_leaf_node, new_split_node)
503-
tree.set_node(new_nodes[0].index, new_nodes[0])
504-
tree.set_node(new_nodes[1].index, new_nodes[1])
505-
506-
return index_selected_predictor
485+
tree.grow_leaf_node(current_node, selected_predictor, split_value, index_leaf_node)
486+
tree.set_node(new_nodes[0].index, new_nodes[0])
487+
tree.set_node(new_nodes[1].index, new_nodes[1])
507488

489+
return [new_nodes[0].index, new_nodes[1].index]
508490

509-
def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
510-
left_idx = X[idx_data_points, selected_predictor] <= split_value
511-
left_node_idx_data_points = idx_data_points[left_idx]
512-
right_node_idx_data_points = idx_data_points[~left_idx]
513491

514-
return left_node_idx_data_points, right_node_idx_data_points
492+
@njit
493+
def get_new_idx_data_points(available_splitting_values, split_value, idx_data_points):
494+
split_idx = available_splitting_values <= split_value
495+
return idx_data_points[split_idx], idx_data_points[~split_idx]
515496

516497

517498
def get_split_value(available_splitting_values, idx_data_points, missing_data):
@@ -529,19 +510,18 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
529510
return split_value
530511

531512

532-
def draw_leaf_value(y_mu_pred, m, normal, kfactor, shape):
513+
@njit
514+
def draw_leaf_value(y_mu_pred, m, norm, shape):
533515
"""Draw Gaussian distributed leaf values."""
534516
if y_mu_pred.size == 0:
535517
return np.zeros(shape)
518+
519+
if y_mu_pred.size == 1:
520+
mu_mean = np.full(shape, y_mu_pred.item() / m)
536521
else:
537-
norm = normal.random() * kfactor
538-
if y_mu_pred.size == 1:
539-
mu_mean = np.full(shape, y_mu_pred.item() / m)
540-
else:
541-
mu_mean = fast_mean(y_mu_pred) / m
522+
mu_mean = fast_mean(y_mu_pred) / m
542523

543-
draw = norm + mu_mean
544-
return draw
524+
return norm + mu_mean
545525

546526

547527
@njit

pymc_bart/tree.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,11 @@ def set_node(self, index, node):
8888
if node.is_leaf_node():
8989
self.idx_leaf_nodes.append(index)
9090

91-
def delete_leaf_node(self, index):
92-
self.idx_leaf_nodes.remove(index)
93-
del self.tree_structure[index]
91+
def grow_leaf_node(self, current_node, selected_predictor, split_value, index_leaf_node):
92+
current_node.value = split_value
93+
current_node.idx_split_variable = selected_predictor
94+
current_node.idx_data_points = None
95+
self.idx_leaf_nodes.remove(index_leaf_node)
9496

9597
def trim(self):
9698
tree = {
@@ -100,9 +102,9 @@ def trim(self):
100102
return Tree(tree, None, None)
101103

102104
def get_split_variables(self):
103-
return [
104-
node.idx_split_variable for node in self.tree_structure.values() if node.is_split_node()
105-
]
105+
for node in self.tree_structure.values():
106+
if node.is_split_node():
107+
yield node.idx_split_variable
106108

107109
def _predict(self):
108110
output = self.output
@@ -153,11 +155,10 @@ def _traverse_tree(self, x, node_index, excluded):
153155
return np.mean(leaf_values, 0)
154156

155157
if x[current_node.idx_split_variable] <= current_node.value:
156-
left_child = current_node.get_idx_left_child()
157-
return self._traverse_tree(x, left_child, excluded)
158+
next_node = current_node.get_idx_left_child()
158159
else:
159-
right_child = current_node.get_idx_right_child()
160-
return self._traverse_tree(x, right_child, excluded)
160+
next_node = current_node.get_idx_right_child()
161+
return self._traverse_tree(x, next_node, excluded)
161162

162163
def _traverse_leaf_values(self, leaf_values, node_index):
163164
"""
@@ -196,14 +197,11 @@ def new_leaf_node(cls, index: int, value, idx_data_points) -> "Node":
196197
def new_split_node(cls, index: int, split_value, idx_split_variable) -> "Node":
197198
return cls(index, value=split_value, idx_split_variable=idx_split_variable)
198199

199-
def get_idx_parent_node(self) -> int:
200-
return (self.index - 1) // 2
201-
202200
def get_idx_left_child(self) -> int:
203201
return self.index * 2 + 1
204202

205203
def get_idx_right_child(self) -> int:
206-
return self.get_idx_left_child() + 1
204+
return self.index * 2 + 2
207205

208206
def is_split_node(self) -> bool:
209207
return self.idx_split_variable >= 0

tests/test_tree.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ def test_split_node():
1010
assert split_node.value == 3.0
1111
assert split_node.idx_split_variable == 2
1212
assert split_node.idx_data_points is None
13-
assert split_node.get_idx_parent_node() == 2
1413
assert split_node.get_idx_left_child() == 11
1514
assert split_node.get_idx_right_child() == 12
1615
assert split_node.is_split_node() is True
@@ -24,7 +23,6 @@ def test_leaf_node():
2423
assert leaf_node.value == 3.14
2524
assert leaf_node.idx_split_variable == -1
2625
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
27-
assert leaf_node.get_idx_parent_node() == 2
2826
assert leaf_node.get_idx_left_child() == 11
2927
assert leaf_node.get_idx_right_child() == 12
3028
assert leaf_node.is_split_node() is False

0 commit comments

Comments
 (0)