Skip to content

Commit 0f0e361

Browse files
authored
use slots (#17)
1 parent 4665786 commit 0f0e361

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

pymc_bart/pgbart.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,10 @@ def __init__(
101101
aesara.config.floatX
102102
)
103103
self.sum_trees_noi = self.sum_trees - (self.init_mean / self.m)
104-
self.a_tree = Tree.init_tree(
104+
self.a_tree = Tree(
105105
leaf_node_value=self.init_mean / self.m,
106106
idx_data_points=np.arange(self.num_observations, dtype="int32"),
107+
num_observations=self.num_observations,
107108
shape=self.shape,
108109
)
109110
self.normal = NormalSampler(mu_std, self.shape)
@@ -297,6 +298,8 @@ def competence(var, has_grad):
297298
class ParticleTree:
298299
"""Particle tree."""
299300

301+
__slots__ = "tree", "expansion_nodes", "log_weight", "old_likelihood_logp", "kfactor"
302+
300303
def __init__(self, tree):
301304
self.tree = tree.copy() # keeps the tree that we care at the moment
302305
self.expansion_nodes = [0]

pymc_bart/tree.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,31 @@ class Tree:
3636
of the tree itself.
3737
idx_leaf_nodes : list
3838
List with the index of the leaf nodes of the tree.
39-
num_observations : int
40-
Number of observations used to fit BART.
41-
m : int
42-
Number of trees
39+
output: array
40+
Array of shape number of observations, shape
4341
4442
Parameters
4543
----------
46-
num_observations : int, optional
44+
leaf_node_value : int or float
45+
idx_data_points : array of integers
46+
num_observations : integer
47+
shape : int
4748
"""
4849

49-
def __init__(self, num_observations=0, shape=1):
50-
self.tree_structure = {}
51-
self.idx_leaf_nodes = []
50+
__slots__ = (
51+
"tree_structure",
52+
"idx_leaf_nodes",
53+
"output",
54+
"leaf_node_value",
55+
"idx_data_points",
56+
"shape",
57+
)
58+
59+
def __init__(self, leaf_node_value, idx_data_points, num_observations, shape):
60+
self.tree_structure = {
61+
0: LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
62+
}
63+
self.idx_leaf_nodes = [0]
5264
self.output = np.zeros((num_observations, shape)).astype(aesara.config.floatX).squeeze()
5365

5466
def __getitem__(self, index):
@@ -169,26 +181,10 @@ def _traverse_leaf_values(self, leaf_values, node_index):
169181
else:
170182
leaf_values.append(current_node.value)
171183

172-
@staticmethod
173-
def init_tree(leaf_node_value, idx_data_points, shape):
174-
"""
175-
Initialize tree.
176-
177-
Parameters
178-
----------
179-
leaf_node_value
180-
idx_data_points
181-
182-
Returns
183-
-------
184-
tree
185-
"""
186-
new_tree = Tree(len(idx_data_points), shape)
187-
new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
188-
return new_tree
189-
190184

191185
class BaseNode:
186+
__slots__ = "index", "depth"
187+
192188
def __init__(self, index):
193189
self.index = index
194190
self.depth = int(math.floor(math.log(index + 1, 2)))
@@ -204,6 +200,8 @@ def get_idx_right_child(self):
204200

205201

206202
class SplitNode(BaseNode):
203+
__slots__ = "index", "idx_split_variable", "split_value"
204+
207205
def __init__(self, index, idx_split_variable, split_value):
208206
super().__init__(index)
209207

@@ -212,6 +210,8 @@ def __init__(self, index, idx_split_variable, split_value):
212210

213211

214212
class LeafNode(BaseNode):
213+
__slots__ = "index", "value", "idx_data_points"
214+
215215
def __init__(self, index, value, idx_data_points):
216216
super().__init__(index)
217217
self.value = value

0 commit comments

Comments
 (0)