Skip to content

Commit 30e5b0f

Browse files
Improvements over Tree implementation (#40)
* Improvements over Tree implementation: - Replace the depth attribute on the Node class with a function (get_depth) that computes it from the node index; - Remove the leaf_node_value attribute from the Tree class, as it is not needed: it was only ever written to and never read; * Update pymc_bart/tree.py Co-authored-by: Osvaldo A Martin <[email protected]> Co-authored-by: Osvaldo A Martin <[email protected]>
1 parent 8311007 commit 30e5b0f

File tree

3 files changed

+13
-14
lines changed

3 files changed

+13
-14
lines changed

pymc_bart/pgbart.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@
2727
from pymc.step_methods.arraystep import ArrayStepShared, Competence
2828
from pymc.pytensorf import inputvars, join_nonshared_inputs, make_shared_replacements
2929

30-
3130
from pymc_bart.bart import BARTRV
32-
from pymc_bart.tree import Tree, Node
31+
from pymc_bart.tree import Tree, Node, get_depth
3332

3433
_log = logging.getLogger("pymc")
3534

@@ -138,9 +137,7 @@ def __init__(
138137
self.likelihood_logp = logp(initial_values, [model.datalogp], vars, shared)
139138
self.all_particles = []
140139
for _ in range(self.m):
141-
self.a_tree.leaf_node_value = init_mean / self.m
142-
p = ParticleTree(self.a_tree)
143-
self.all_particles.append(p)
140+
self.all_particles.append(ParticleTree(self.a_tree))
144141
self.all_trees = np.array([p.tree for p in self.all_particles])
145142
super().__init__(vars, shared)
146143

@@ -352,7 +349,7 @@ def sample_tree(
352349
if self.expansion_nodes:
353350
index_leaf_node = self.expansion_nodes.pop(0)
354351
# Probability that this node will remain a leaf node
355-
prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
352+
prob_leaf = prior_prob_leaf_node[get_depth(index_leaf_node)]
356353

357354
if prob_leaf < np.random.random():
358355
index_selected_predictor = grow_tree(

pymc_bart/tree.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import math
1616

1717
from copy import deepcopy
18+
from functools import lru_cache
1819

1920
from pytensor import config
2021
import numpy as np
@@ -41,7 +42,6 @@ class Tree:
4142
4243
Parameters
4344
----------
44-
leaf_node_value : int or float
4545
idx_data_points : array of integers
4646
num_observations : integer
4747
shape : int
@@ -51,7 +51,6 @@ class Tree:
5151
"tree_structure",
5252
"idx_leaf_nodes",
5353
"output",
54-
"leaf_node_value",
5554
)
5655

5756
def __init__(self, leaf_node_value, idx_data_points, num_observations, shape):
@@ -88,7 +87,6 @@ def trim(self):
8887
del a_tree.idx_leaf_nodes
8988
for k in a_tree.tree_structure.keys():
9089
current_node = a_tree[k]
91-
del current_node.depth
9290
if current_node.is_leaf_node():
9391
del current_node.idx_data_points
9492
return a_tree
@@ -174,11 +172,10 @@ def _traverse_leaf_values(self, leaf_values, node_index):
174172

175173

176174
class Node:
177-
__slots__ = "index", "depth", "value", "idx_split_variable", "idx_data_points"
175+
__slots__ = "index", "value", "idx_split_variable", "idx_data_points"
178176

179177
def __init__(self, index: int, value=-1, idx_data_points=None, idx_split_variable=-1):
180178
self.index = index
181-
self.depth = int(math.floor(math.log(index + 1, 2)))
182179
self.value = value
183180
self.idx_data_points = idx_data_points
184181
self.idx_split_variable = idx_split_variable
@@ -205,3 +202,8 @@ def is_split_node(self) -> bool:
205202

206203
def is_leaf_node(self) -> bool:
207204
return not self.is_split_node()
205+
206+
207+
@lru_cache
208+
def get_depth(index: int) -> int:
209+
return math.floor(math.log2(index + 1))

tests/test_tree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import numpy as np
22

3-
from pymc_bart.tree import Node
3+
from pymc_bart.tree import Node, get_depth
44

55

66
def test_split_node():
77
split_node = Node.new_split_node(index=5, idx_split_variable=2, split_value=3.0)
88
assert split_node.index == 5
9-
assert split_node.depth == 2
9+
assert get_depth(split_node.index) == 2
1010
assert split_node.value == 3.0
1111
assert split_node.idx_split_variable == 2
1212
assert split_node.idx_data_points is None
@@ -20,7 +20,7 @@ def test_split_node():
2020
def test_leaf_node():
2121
leaf_node = Node.new_leaf_node(index=5, value=3.14, idx_data_points=[1, 2, 3])
2222
assert leaf_node.index == 5
23-
assert leaf_node.depth == 2
23+
assert get_depth(leaf_node.index) == 2
2424
assert leaf_node.value == 3.14
2525
assert leaf_node.idx_split_variable == -1
2626
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])

0 commit comments

Comments
 (0)