@@ -32,7 +32,7 @@ class Tree:
32
32
A dictionary that represents the nodes stored in breadth-first order, based on the array
33
33
method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
34
34
The dictionary's keys are integers that represent each node's position.
35
- The dictionary's values are objects of type SplitNode or LeafNode that represent the nodes
35
+ The dictionary's values are objects of type Node that represent the split and leaf nodes
36
36
of the tree itself.
37
37
idx_leaf_nodes : list
38
38
List with the index of the leaf nodes of the tree.
@@ -56,7 +56,7 @@ class Tree:
56
56
57
57
def __init__ (self , leaf_node_value , idx_data_points , num_observations , shape ):
58
58
self .tree_structure = {
59
- 0 : LeafNode ( index = 0 , value = leaf_node_value , idx_data_points = idx_data_points )
59
+ 0 : Node . new_leaf_node ( 0 , value = leaf_node_value , idx_data_points = idx_data_points )
60
60
}
61
61
self .idx_leaf_nodes = [0 ]
62
62
self .output = np .zeros ((num_observations , shape )).astype (config .floatX ).squeeze ()
@@ -70,12 +70,12 @@ def __setitem__(self, index, node):
70
70
def copy (self ):
71
71
return deepcopy (self )
72
72
73
- def get_node (self , index ):
73
+ def get_node (self , index ) -> "Node" :
74
74
return self .tree_structure [index ]
75
75
76
76
def set_node (self , index , node ):
77
77
self .tree_structure [index ] = node
78
- if isinstance ( node , LeafNode ):
78
+ if node . is_leaf_node ( ):
79
79
self .idx_leaf_nodes .append (index )
80
80
81
81
def delete_leaf_node (self , index ):
@@ -89,15 +89,13 @@ def trim(self):
89
89
for k in a_tree .tree_structure .keys ():
90
90
current_node = a_tree [k ]
91
91
del current_node .depth
92
- if isinstance ( current_node , LeafNode ):
92
+ if current_node . is_leaf_node ( ):
93
93
del current_node .idx_data_points
94
94
return a_tree
95
95
96
96
def get_split_variables (self ):
97
97
return [
98
- node .idx_split_variable
99
- for node in self .tree_structure .values ()
100
- if isinstance (node , SplitNode )
98
+ node .idx_split_variable for node in self .tree_structure .values () if node .is_split_node ()
101
99
]
102
100
103
101
def _predict (self ):
@@ -115,6 +113,8 @@ def predict(self, x, excluded=None):
115
113
----------
116
114
x : numpy array
117
115
Unobserved point
116
+ excluded : list
117
+ Indexes of the variables to exclude when computing predictions
118
118
119
119
Returns
120
120
-------
@@ -123,12 +123,7 @@ def predict(self, x, excluded=None):
123
123
"""
124
124
if excluded is None :
125
125
excluded = []
126
- node = self ._traverse_tree (x , 0 , excluded )
127
- if isinstance (node , LeafNode ):
128
- leaf_value = node .value
129
- else :
130
- leaf_value = node
131
- return leaf_value
126
+ return self ._traverse_tree (x , 0 , excluded )
132
127
133
128
def _traverse_tree (self , x , node_index , excluded ):
134
129
"""
@@ -141,22 +136,22 @@ def _traverse_tree(self, x, node_index, excluded):
141
136
142
137
Returns
143
138
-------
144
- LeafNode or mean of leaf node values
139
+ Leaf node value or mean of leaf node values
145
140
"""
146
141
current_node = self .get_node (node_index )
147
- if isinstance ( current_node , SplitNode ):
148
- if current_node .idx_split_variable in excluded :
149
- leaf_values = []
150
- self . _traverse_leaf_values ( leaf_values , node_index )
151
- return np . mean (leaf_values , 0 )
152
-
153
- if x [ current_node . idx_split_variable ] <= current_node . split_value :
154
- left_child = current_node .get_idx_left_child ()
155
- current_node = self . _traverse_tree ( x , left_child , excluded )
156
- else :
157
- right_child = current_node . get_idx_right_child ()
158
- current_node = self . _traverse_tree ( x , right_child , excluded )
159
- return current_node
142
+ if current_node . is_leaf_node ( ):
143
+ return current_node .value
144
+ if current_node . idx_split_variable in excluded :
145
+ leaf_values = []
146
+ self . _traverse_leaf_values (leaf_values , node_index )
147
+ return np . mean ( leaf_values , 0 )
148
+
149
+ if x [ current_node . idx_split_variable ] < = current_node .value :
150
+ left_child = current_node . get_idx_left_child ( )
151
+ return self . _traverse_tree ( x , left_child , excluded )
152
+ else :
153
+ right_child = current_node . get_idx_right_child ( )
154
+ return self . _traverse_tree ( x , right_child , excluded )
160
155
161
156
def _traverse_leaf_values (self , leaf_values , node_index ):
162
157
"""
@@ -170,47 +165,43 @@ def _traverse_leaf_values(self, leaf_values, node_index):
170
165
-------
171
166
List of leaf node values
172
167
"""
173
- current_node = self .get_node (node_index )
174
- if isinstance (current_node , SplitNode ):
175
- left_child = current_node .get_idx_left_child ()
176
- self ._traverse_leaf_values (leaf_values , left_child )
177
- right_child = current_node .get_idx_right_child ()
178
- self ._traverse_leaf_values (leaf_values , right_child )
168
+ node = self .get_node (node_index )
169
+ if node .is_leaf_node ():
170
+ leaf_values .append (node .value )
179
171
else :
180
- leaf_values .append (current_node .value )
172
+ self ._traverse_leaf_values (leaf_values , node .get_idx_left_child ())
173
+ self ._traverse_leaf_values (leaf_values , node .get_idx_right_child ())
181
174
182
175
183
- class BaseNode :
184
- __slots__ = "index" , "depth"
176
+ class Node :
177
+ __slots__ = "index" , "depth" , "value" , "idx_split_variable" , "idx_data_points"
185
178
186
- def __init__ (self , index ):
179
+ def __init__ (self , index : int , value = - 1 , idx_data_points = None , idx_split_variable = - 1 ):
187
180
self .index = index
188
181
self .depth = int (math .floor (math .log (index + 1 , 2 )))
182
+ self .value = value
183
+ self .idx_data_points = idx_data_points
184
+ self .idx_split_variable = idx_split_variable
189
185
190
- def get_idx_parent_node (self ):
186
+ @classmethod
187
+ def new_leaf_node (cls , index : int , value , idx_data_points ) -> "Node" :
188
+ return cls (index , value = value , idx_data_points = idx_data_points )
189
+
190
+ @classmethod
191
+ def new_split_node (cls , index : int , split_value , idx_split_variable ) -> "Node" :
192
+ return cls (index , value = split_value , idx_split_variable = idx_split_variable )
193
+
194
+ def get_idx_parent_node (self ) -> int :
191
195
return (self .index - 1 ) // 2
192
196
193
- def get_idx_left_child (self ):
197
+ def get_idx_left_child (self ) -> int :
194
198
return self .index * 2 + 1
195
199
196
- def get_idx_right_child (self ):
200
+ def get_idx_right_child (self ) -> int :
197
201
return self .get_idx_left_child () + 1
198
202
203
+ def is_split_node (self ) -> bool :
204
+ return self .idx_split_variable >= 0
199
205
200
- class SplitNode (BaseNode ):
201
- __slots__ = "idx_split_variable" , "split_value"
202
-
203
- def __init__ (self , index , idx_split_variable , split_value ):
204
- super ().__init__ (index )
205
-
206
- self .idx_split_variable = idx_split_variable
207
- self .split_value = split_value
208
-
209
-
210
- class LeafNode (BaseNode ):
211
- __slots__ = "value" , "idx_data_points"
212
-
213
- def __init__ (self , index , value , idx_data_points ):
214
- super ().__init__ (index )
215
- self .value = value
216
- self .idx_data_points = idx_data_points
206
+ def is_leaf_node (self ) -> bool :
207
+ return not self .is_split_node ()
0 commit comments