12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ from collections .abc import Generator
15
16
from functools import lru_cache
16
- from typing import Dict , Generator , List , Optional , Tuple , Union
17
+ from typing import Optional , Union
17
18
18
19
import numpy as np
19
20
import numpy .typing as npt
@@ -30,7 +31,7 @@ class Node:
30
31
value : npt.NDArray[np.float64]
31
32
idx_data_points : Optional[npt.NDArray[np.int_]]
32
33
idx_split_variable : int
33
- linear_params: Optional[List [float]] = None
34
+ linear_params: Optional[list [float]] = None
34
35
"""
35
36
36
37
__slots__ = "value" , "nvalue" , "idx_split_variable" , "idx_data_points" , "linear_params"
@@ -41,7 +42,7 @@ def __init__(
41
42
nvalue : int = 0 ,
42
43
idx_data_points : Optional [npt .NDArray [np .int_ ]] = None ,
43
44
idx_split_variable : int = - 1 ,
44
- linear_params : Optional [List [npt .NDArray [np .float64 ]]] = None ,
45
+ linear_params : Optional [list [npt .NDArray [np .float64 ]]] = None ,
45
46
) -> None :
46
47
self .value = value
47
48
self .nvalue = nvalue
@@ -56,7 +57,7 @@ def new_leaf_node(
56
57
nvalue : int = 0 ,
57
58
idx_data_points : Optional [npt .NDArray [np .int_ ]] = None ,
58
59
idx_split_variable : int = - 1 ,
59
- linear_params : Optional [List [npt .NDArray [np .float64 ]]] = None ,
60
+ linear_params : Optional [list [npt .NDArray [np .float64 ]]] = None ,
60
61
) -> "Node" :
61
62
return cls (
62
63
value = value ,
@@ -94,19 +95,19 @@ class Tree:
94
95
95
96
Attributes
96
97
----------
97
- tree_structure : Dict [int, Node]
98
+ tree_structure : dict [int, Node]
98
99
A dictionary that represents the nodes stored in breadth-first order, based in the array
99
100
method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
100
101
The dictionary's keys are integers that represent the nodes position.
101
102
The dictionary's values are objects of type Node that represent the split and leaf nodes
102
103
of the tree itself.
103
104
output: Optional[npt.NDArray[np.float64]]
104
105
Array of shape number of observations, shape
105
- split_rules : List [SplitRule]
106
+ split_rules : list [SplitRule]
106
107
List of SplitRule objects, one per column in input data.
107
108
Allows using different split rules for different columns. Default is ContinuousSplitRule.
108
109
Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
109
- idx_leaf_nodes : Optional[List [int]], by default None.
110
+ idx_leaf_nodes : Optional[list [int]], by default None.
110
111
Array with the index of the leaf nodes of the tree.
111
112
112
113
Parameters
@@ -120,10 +121,10 @@ class Tree:
120
121
121
122
def __init__ (
122
123
self ,
123
- tree_structure : Dict [int , Node ],
124
+ tree_structure : dict [int , Node ],
124
125
output : npt .NDArray [np .float64 ],
125
- split_rules : List [SplitRule ],
126
- idx_leaf_nodes : Optional [List [int ]] = None ,
126
+ split_rules : list [SplitRule ],
127
+ idx_leaf_nodes : Optional [list [int ]] = None ,
127
128
) -> None :
128
129
self .tree_structure = tree_structure
129
130
self .idx_leaf_nodes = idx_leaf_nodes
@@ -137,7 +138,7 @@ def new_tree(
137
138
idx_data_points : Optional [npt .NDArray [np .int_ ]],
138
139
num_observations : int ,
139
140
shape : int ,
140
- split_rules : List [SplitRule ],
141
+ split_rules : list [SplitRule ],
141
142
) -> "Tree" :
142
143
return cls (
143
144
tree_structure = {
@@ -159,7 +160,7 @@ def __setitem__(self, index, node) -> None:
159
160
self .set_node (index , node )
160
161
161
162
def copy (self ) -> "Tree" :
162
- tree : Dict [int , Node ] = {
163
+ tree : dict [int , Node ] = {
163
164
k : Node (
164
165
value = v .value ,
165
166
nvalue = v .nvalue ,
@@ -199,7 +200,7 @@ def grow_leaf_node(
199
200
self .idx_leaf_nodes .remove (index_leaf_node )
200
201
201
202
def trim (self ) -> "Tree" :
202
- tree : Dict [int , Node ] = {
203
+ tree : dict [int , Node ] = {
203
204
k : Node (
204
205
value = v .value ,
205
206
nvalue = v .nvalue ,
@@ -233,7 +234,7 @@ def _predict(self) -> npt.NDArray[np.float64]:
233
234
def predict (
234
235
self ,
235
236
x : npt .NDArray [np .float64 ],
236
- excluded : Optional [List [int ]] = None ,
237
+ excluded : Optional [list [int ]] = None ,
237
238
shape : int = 1 ,
238
239
) -> npt .NDArray [np .float64 ]:
239
240
"""
@@ -243,7 +244,7 @@ def predict(
243
244
----------
244
245
x : npt.NDArray[np.float64]
245
246
Unobserved point
246
- excluded: Optional[List [int]]
247
+ excluded: Optional[list [int]]
247
248
Indexes of the variables to exclude when computing predictions
248
249
249
250
Returns
@@ -259,8 +260,8 @@ def predict(
259
260
def _traverse_tree (
260
261
self ,
261
262
X : npt .NDArray [np .float64 ],
262
- excluded : Optional [List [int ]] = None ,
263
- shape : Union [int , Tuple [int , ...]] = 1 ,
263
+ excluded : Optional [list [int ]] = None ,
264
+ shape : Union [int , tuple [int , ...]] = 1 ,
264
265
) -> npt .NDArray [np .float64 ]:
265
266
"""
266
267
Traverse the tree starting from the root node given an (un)observed point.
@@ -273,7 +274,7 @@ def _traverse_tree(
273
274
Index of the node to start the traversal from
274
275
split_variable : int
275
276
Index of the variable used to split the node
276
- excluded: Optional[List [int]]
277
+ excluded: Optional[list [int]]
277
278
Indexes of the variables to exclude when computing predictions
278
279
279
280
Returns
@@ -327,14 +328,14 @@ def _traverse_tree(
327
328
return p_d
328
329
329
330
def _traverse_leaf_values (
330
- self , leaf_values : List [npt .NDArray [np .float64 ]], leaf_n_values : List [int ], node_index : int
331
+ self , leaf_values : list [npt .NDArray [np .float64 ]], leaf_n_values : list [int ], node_index : int
331
332
) -> None :
332
333
"""
333
334
Traverse the tree appending leaf values starting from a particular node.
334
335
335
336
Parameters
336
337
----------
337
- leaf_values : List [npt.NDArray[np.float64]]
338
+ leaf_values : list [npt.NDArray[np.float64]]
338
339
node_index : int
339
340
"""
340
341
node = self .get_node (node_index )
0 commit comments