@@ -36,19 +36,31 @@ class Tree:
36
36
of the tree itself.
37
37
idx_leaf_nodes : list
38
38
List with the index of the leaf nodes of the tree.
39
- num_observations : int
40
- Number of observations used to fit BART.
41
- m : int
42
- Number of trees
39
+ output: array
40
+ Array of shape number of observations, shape
43
41
44
42
Parameters
45
43
----------
46
- num_observations : int, optional
44
+ leaf_node_value : int or float
45
+ idx_data_points : array of integers
46
+ num_observations : integer
47
+ shape : int
47
48
"""
48
49
49
- def __init__ (self , num_observations = 0 , shape = 1 ):
50
- self .tree_structure = {}
51
- self .idx_leaf_nodes = []
50
+ __slots__ = (
51
+ "tree_structure" ,
52
+ "idx_leaf_nodes" ,
53
+ "output" ,
54
+ "leaf_node_value" ,
55
+ "idx_data_points" ,
56
+ "shape" ,
57
+ )
58
+
59
+ def __init__ (self , leaf_node_value , idx_data_points , num_observations , shape ):
60
+ self .tree_structure = {
61
+ 0 : LeafNode (index = 0 , value = leaf_node_value , idx_data_points = idx_data_points )
62
+ }
63
+ self .idx_leaf_nodes = [0 ]
52
64
self .output = np .zeros ((num_observations , shape )).astype (aesara .config .floatX ).squeeze ()
53
65
54
66
def __getitem__ (self , index ):
@@ -169,26 +181,10 @@ def _traverse_leaf_values(self, leaf_values, node_index):
169
181
else :
170
182
leaf_values .append (current_node .value )
171
183
172
- @staticmethod
173
- def init_tree (leaf_node_value , idx_data_points , shape ):
174
- """
175
- Initialize tree.
176
-
177
- Parameters
178
- ----------
179
- leaf_node_value
180
- idx_data_points
181
-
182
- Returns
183
- -------
184
- tree
185
- """
186
- new_tree = Tree (len (idx_data_points ), shape )
187
- new_tree [0 ] = LeafNode (index = 0 , value = leaf_node_value , idx_data_points = idx_data_points )
188
- return new_tree
189
-
190
184
191
185
class BaseNode :
186
+ __slots__ = "index" , "depth"
187
+
192
188
def __init__ (self , index ):
193
189
self .index = index
194
190
self .depth = int (math .floor (math .log (index + 1 , 2 )))
@@ -204,6 +200,8 @@ def get_idx_right_child(self):
204
200
205
201
206
202
class SplitNode (BaseNode ):
203
+ __slots__ = "index" , "idx_split_variable" , "split_value"
204
+
207
205
def __init__ (self , index , idx_split_variable , split_value ):
208
206
super ().__init__ (index )
209
207
@@ -212,6 +210,8 @@ def __init__(self, index, idx_split_variable, split_value):
212
210
213
211
214
212
class LeafNode (BaseNode ):
213
+ __slots__ = "index" , "value" , "idx_data_points"
214
+
215
215
def __init__ (self , index , value , idx_data_points ):
216
216
super ().__init__ (index )
217
217
self .value = value
0 commit comments