Skip to content

Commit 6aad349

Browse files
authored
Refactor the classes BaseNode, SplitNode and LeafNode into a single class Node, avoiding inheritance so that numba JitClasses can be used in the near future (#35).
Change the behaviour of _traverse_tree so that it returns the leaf node value, or the mean of the leaf node values.
1 parent 4dff795 commit 6aad349

File tree

4 files changed

+85
-85
lines changed

4 files changed

+85
-85
lines changed

pymc_bart/pgbart.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929

3030

3131
from pymc_bart.bart import BARTRV
32-
from pymc_bart.tree import LeafNode, SplitNode, Tree
33-
32+
from pymc_bart.tree import Tree, Node
3433

3534
_log = logging.getLogger("pymc")
3635

@@ -412,7 +411,7 @@ def rvs(self):
412411

413412
def compute_prior_probability(alpha):
414413
"""
415-
Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
414+
Calculate the probability of the node being a leaf node (1 - p(being split node)).
416415
417416
Taken from equation 19 in [Rockova2018].
418417
@@ -480,17 +479,17 @@ def grow_tree(
480479
shape,
481480
)
482481

483-
new_node = LeafNode(
482+
new_node = Node.new_leaf_node(
484483
index=current_node_children[idx],
485484
value=node_value,
486485
idx_data_points=idx_data_point,
487486
)
488487
new_nodes.append(new_node)
489488

490-
new_split_node = SplitNode(
489+
new_split_node = Node.new_split_node(
491490
index=index_leaf_node,
492-
idx_split_variable=selected_predictor,
493491
split_value=split_value,
492+
idx_split_variable=selected_predictor,
494493
)
495494

496495
# update tree nodes and indexes

pymc_bart/tree.py

Lines changed: 49 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Tree:
3232
A dictionary that represents the nodes stored in breadth-first order, based on the array
3333
method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
3434
The dictionary's keys are integers that represent the nodes' positions.
35-
The dictionary's values are objects of type SplitNode or LeafNode that represent the nodes
35+
The dictionary's values are objects of type Node that represent the split and leaf nodes
3636
of the tree itself.
3737
idx_leaf_nodes : list
3838
List with the index of the leaf nodes of the tree.
@@ -56,7 +56,7 @@ class Tree:
5656

5757
def __init__(self, leaf_node_value, idx_data_points, num_observations, shape):
5858
self.tree_structure = {
59-
0: LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
59+
0: Node.new_leaf_node(0, value=leaf_node_value, idx_data_points=idx_data_points)
6060
}
6161
self.idx_leaf_nodes = [0]
6262
self.output = np.zeros((num_observations, shape)).astype(config.floatX).squeeze()
@@ -70,12 +70,12 @@ def __setitem__(self, index, node):
7070
def copy(self):
7171
return deepcopy(self)
7272

73-
def get_node(self, index):
73+
def get_node(self, index) -> "Node":
7474
return self.tree_structure[index]
7575

7676
def set_node(self, index, node):
7777
self.tree_structure[index] = node
78-
if isinstance(node, LeafNode):
78+
if node.is_leaf_node():
7979
self.idx_leaf_nodes.append(index)
8080

8181
def delete_leaf_node(self, index):
@@ -89,15 +89,13 @@ def trim(self):
8989
for k in a_tree.tree_structure.keys():
9090
current_node = a_tree[k]
9191
del current_node.depth
92-
if isinstance(current_node, LeafNode):
92+
if current_node.is_leaf_node():
9393
del current_node.idx_data_points
9494
return a_tree
9595

9696
def get_split_variables(self):
9797
return [
98-
node.idx_split_variable
99-
for node in self.tree_structure.values()
100-
if isinstance(node, SplitNode)
98+
node.idx_split_variable for node in self.tree_structure.values() if node.is_split_node()
10199
]
102100

103101
def _predict(self):
@@ -115,6 +113,8 @@ def predict(self, x, excluded=None):
115113
----------
116114
x : numpy array
117115
Unobserved point
116+
excluded : list
117+
Indexes of the variables to exclude when computing predictions
118118
119119
Returns
120120
-------
@@ -123,12 +123,7 @@ def predict(self, x, excluded=None):
123123
"""
124124
if excluded is None:
125125
excluded = []
126-
node = self._traverse_tree(x, 0, excluded)
127-
if isinstance(node, LeafNode):
128-
leaf_value = node.value
129-
else:
130-
leaf_value = node
131-
return leaf_value
126+
return self._traverse_tree(x, 0, excluded)
132127

133128
def _traverse_tree(self, x, node_index, excluded):
134129
"""
@@ -141,22 +136,22 @@ def _traverse_tree(self, x, node_index, excluded):
141136
142137
Returns
143138
-------
144-
LeafNode or mean of leaf node values
139+
Leaf node value or mean of leaf node values
145140
"""
146141
current_node = self.get_node(node_index)
147-
if isinstance(current_node, SplitNode):
148-
if current_node.idx_split_variable in excluded:
149-
leaf_values = []
150-
self._traverse_leaf_values(leaf_values, node_index)
151-
return np.mean(leaf_values, 0)
152-
153-
if x[current_node.idx_split_variable] <= current_node.split_value:
154-
left_child = current_node.get_idx_left_child()
155-
current_node = self._traverse_tree(x, left_child, excluded)
156-
else:
157-
right_child = current_node.get_idx_right_child()
158-
current_node = self._traverse_tree(x, right_child, excluded)
159-
return current_node
142+
if current_node.is_leaf_node():
143+
return current_node.value
144+
if current_node.idx_split_variable in excluded:
145+
leaf_values = []
146+
self._traverse_leaf_values(leaf_values, node_index)
147+
return np.mean(leaf_values, 0)
148+
149+
if x[current_node.idx_split_variable] <= current_node.value:
150+
left_child = current_node.get_idx_left_child()
151+
return self._traverse_tree(x, left_child, excluded)
152+
else:
153+
right_child = current_node.get_idx_right_child()
154+
return self._traverse_tree(x, right_child, excluded)
160155

161156
def _traverse_leaf_values(self, leaf_values, node_index):
162157
"""
@@ -170,47 +165,43 @@ def _traverse_leaf_values(self, leaf_values, node_index):
170165
-------
171166
List of leaf node values
172167
"""
173-
current_node = self.get_node(node_index)
174-
if isinstance(current_node, SplitNode):
175-
left_child = current_node.get_idx_left_child()
176-
self._traverse_leaf_values(leaf_values, left_child)
177-
right_child = current_node.get_idx_right_child()
178-
self._traverse_leaf_values(leaf_values, right_child)
168+
node = self.get_node(node_index)
169+
if node.is_leaf_node():
170+
leaf_values.append(node.value)
179171
else:
180-
leaf_values.append(current_node.value)
172+
self._traverse_leaf_values(leaf_values, node.get_idx_left_child())
173+
self._traverse_leaf_values(leaf_values, node.get_idx_right_child())
181174

182175

183-
class BaseNode:
184-
__slots__ = "index", "depth"
176+
class Node:
177+
__slots__ = "index", "depth", "value", "idx_split_variable", "idx_data_points"
185178

186-
def __init__(self, index):
179+
def __init__(self, index: int, value=-1, idx_data_points=None, idx_split_variable=-1):
187180
self.index = index
188181
self.depth = int(math.floor(math.log(index + 1, 2)))
182+
self.value = value
183+
self.idx_data_points = idx_data_points
184+
self.idx_split_variable = idx_split_variable
189185

190-
def get_idx_parent_node(self):
186+
@classmethod
187+
def new_leaf_node(cls, index: int, value, idx_data_points) -> "Node":
188+
return cls(index, value=value, idx_data_points=idx_data_points)
189+
190+
@classmethod
191+
def new_split_node(cls, index: int, split_value, idx_split_variable) -> "Node":
192+
return cls(index, value=split_value, idx_split_variable=idx_split_variable)
193+
194+
def get_idx_parent_node(self) -> int:
191195
return (self.index - 1) // 2
192196

193-
def get_idx_left_child(self):
197+
def get_idx_left_child(self) -> int:
194198
return self.index * 2 + 1
195199

196-
def get_idx_right_child(self):
200+
def get_idx_right_child(self) -> int:
197201
return self.get_idx_left_child() + 1
198202

203+
def is_split_node(self) -> bool:
204+
return self.idx_split_variable >= 0
199205

200-
class SplitNode(BaseNode):
201-
__slots__ = "idx_split_variable", "split_value"
202-
203-
def __init__(self, index, idx_split_variable, split_value):
204-
super().__init__(index)
205-
206-
self.idx_split_variable = idx_split_variable
207-
self.split_value = split_value
208-
209-
210-
class LeafNode(BaseNode):
211-
__slots__ = "value", "idx_data_points"
212-
213-
def __init__(self, index, value, idx_data_points):
214-
super().__init__(index)
215-
self.value = value
216-
self.idx_data_points = idx_data_points
206+
def is_leaf_node(self) -> bool:
207+
return not self.is_split_node()

tests/test_bart.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,6 @@
88
import pymc_bart as pmb
99

1010

11-
def test_split_node():
12-
split_node = pmb.tree.SplitNode(index=5, idx_split_variable=2, split_value=3.0)
13-
assert split_node.index == 5
14-
assert split_node.idx_split_variable == 2
15-
assert split_node.split_value == 3.0
16-
assert split_node.depth == 2
17-
assert split_node.get_idx_parent_node() == 2
18-
assert split_node.get_idx_left_child() == 11
19-
assert split_node.get_idx_right_child() == 12
20-
21-
22-
def test_leaf_node():
23-
leaf_node = pmb.tree.LeafNode(index=5, value=3.14, idx_data_points=[1, 2, 3])
24-
assert leaf_node.index == 5
25-
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
26-
assert leaf_node.value == 3.14
27-
assert leaf_node.get_idx_parent_node() == 2
28-
assert leaf_node.get_idx_left_child() == 11
29-
assert leaf_node.get_idx_right_child() == 12
30-
31-
3211
def test_bart_vi():
3312
X = np.random.normal(0, 1, size=(250, 3))
3413
Y = np.random.normal(0, 1, size=250)

tests/test_tree.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import numpy as np
2+
3+
from pymc_bart.tree import Node
4+
5+
6+
def test_split_node():
7+
split_node = Node.new_split_node(index=5, idx_split_variable=2, split_value=3.0)
8+
assert split_node.index == 5
9+
assert split_node.depth == 2
10+
assert split_node.value == 3.0
11+
assert split_node.idx_split_variable == 2
12+
assert split_node.idx_data_points is None
13+
assert split_node.get_idx_parent_node() == 2
14+
assert split_node.get_idx_left_child() == 11
15+
assert split_node.get_idx_right_child() == 12
16+
assert split_node.is_split_node() is True
17+
assert split_node.is_leaf_node() is False
18+
19+
20+
def test_leaf_node():
21+
leaf_node = Node.new_leaf_node(index=5, value=3.14, idx_data_points=[1, 2, 3])
22+
assert leaf_node.index == 5
23+
assert leaf_node.depth == 2
24+
assert leaf_node.value == 3.14
25+
assert leaf_node.idx_split_variable == -1
26+
assert np.array_equal(leaf_node.idx_data_points, [1, 2, 3])
27+
assert leaf_node.get_idx_parent_node() == 2
28+
assert leaf_node.get_idx_left_child() == 11
29+
assert leaf_node.get_idx_right_child() == 12
30+
assert leaf_node.is_split_node() is False
31+
assert leaf_node.is_leaf_node() is True

0 commit comments

Comments
 (0)