1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from collections .abc import Generator
1516from functools import lru_cache
16- from typing import Dict , Generator , List , Optional , Tuple , Union
17+ from typing import Optional , Union
1718
1819import numpy as np
1920import numpy .typing as npt
@@ -30,7 +31,7 @@ class Node:
3031 value : npt.NDArray[np.float64]
3132 idx_data_points : Optional[npt.NDArray[np.int_]]
3233 idx_split_variable : int
33- linear_params: Optional[List [float]] = None
34+ linear_params: Optional[list [float]] = None
3435 """
3536
3637 __slots__ = "value" , "nvalue" , "idx_split_variable" , "idx_data_points" , "linear_params"
@@ -41,7 +42,7 @@ def __init__(
4142 nvalue : int = 0 ,
4243 idx_data_points : Optional [npt .NDArray [np .int_ ]] = None ,
4344 idx_split_variable : int = - 1 ,
44- linear_params : Optional [List [npt .NDArray [np .float64 ]]] = None ,
45+ linear_params : Optional [list [npt .NDArray [np .float64 ]]] = None ,
4546 ) -> None :
4647 self .value = value
4748 self .nvalue = nvalue
@@ -56,7 +57,7 @@ def new_leaf_node(
5657 nvalue : int = 0 ,
5758 idx_data_points : Optional [npt .NDArray [np .int_ ]] = None ,
5859 idx_split_variable : int = - 1 ,
59- linear_params : Optional [List [npt .NDArray [np .float64 ]]] = None ,
60+ linear_params : Optional [list [npt .NDArray [np .float64 ]]] = None ,
6061 ) -> "Node" :
6162 return cls (
6263 value = value ,
@@ -94,19 +95,19 @@ class Tree:
9495
9596 Attributes
9697 ----------
97- tree_structure : Dict [int, Node]
98+ tree_structure : dict [int, Node]
9899 A dictionary that represents the nodes stored in breadth-first order, based in the array
99100 method for storing binary trees (https://en.wikipedia.org/wiki/Binary_tree#Arrays).
100101 The dictionary's keys are integers that represent the nodes position.
101102 The dictionary's values are objects of type Node that represent the split and leaf nodes
102103 of the tree itself.
103104 output: Optional[npt.NDArray[np.float64]]
104105 Array of shape number of observations, shape
105- split_rules : List [SplitRule]
106+ split_rules : list [SplitRule]
106107 List of SplitRule objects, one per column in input data.
107108 Allows using different split rules for different columns. Default is ContinuousSplitRule.
108109 Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables.
109- idx_leaf_nodes : Optional[List [int]], by default None.
110+ idx_leaf_nodes : Optional[list [int]], by default None.
110111 Array with the index of the leaf nodes of the tree.
111112
112113 Parameters
@@ -120,10 +121,10 @@ class Tree:
120121
121122 def __init__ (
122123 self ,
123- tree_structure : Dict [int , Node ],
124+ tree_structure : dict [int , Node ],
124125 output : npt .NDArray [np .float64 ],
125- split_rules : List [SplitRule ],
126- idx_leaf_nodes : Optional [List [int ]] = None ,
126+ split_rules : list [SplitRule ],
127+ idx_leaf_nodes : Optional [list [int ]] = None ,
127128 ) -> None :
128129 self .tree_structure = tree_structure
129130 self .idx_leaf_nodes = idx_leaf_nodes
@@ -137,7 +138,7 @@ def new_tree(
137138 idx_data_points : Optional [npt .NDArray [np .int_ ]],
138139 num_observations : int ,
139140 shape : int ,
140- split_rules : List [SplitRule ],
141+ split_rules : list [SplitRule ],
141142 ) -> "Tree" :
142143 return cls (
143144 tree_structure = {
@@ -159,7 +160,7 @@ def __setitem__(self, index, node) -> None:
159160 self .set_node (index , node )
160161
161162 def copy (self ) -> "Tree" :
162- tree : Dict [int , Node ] = {
163+ tree : dict [int , Node ] = {
163164 k : Node (
164165 value = v .value ,
165166 nvalue = v .nvalue ,
@@ -199,7 +200,7 @@ def grow_leaf_node(
199200 self .idx_leaf_nodes .remove (index_leaf_node )
200201
201202 def trim (self ) -> "Tree" :
202- tree : Dict [int , Node ] = {
203+ tree : dict [int , Node ] = {
203204 k : Node (
204205 value = v .value ,
205206 nvalue = v .nvalue ,
@@ -233,7 +234,7 @@ def _predict(self) -> npt.NDArray[np.float64]:
233234 def predict (
234235 self ,
235236 x : npt .NDArray [np .float64 ],
236- excluded : Optional [List [int ]] = None ,
237+ excluded : Optional [list [int ]] = None ,
237238 shape : int = 1 ,
238239 ) -> npt .NDArray [np .float64 ]:
239240 """
@@ -243,7 +244,7 @@ def predict(
243244 ----------
244245 x : npt.NDArray[np.float64]
245246 Unobserved point
246- excluded: Optional[List [int]]
247+ excluded: Optional[list [int]]
247248 Indexes of the variables to exclude when computing predictions
248249
249250 Returns
@@ -259,8 +260,8 @@ def predict(
259260 def _traverse_tree (
260261 self ,
261262 X : npt .NDArray [np .float64 ],
262- excluded : Optional [List [int ]] = None ,
263- shape : Union [int , Tuple [int , ...]] = 1 ,
263+ excluded : Optional [list [int ]] = None ,
264+ shape : Union [int , tuple [int , ...]] = 1 ,
264265 ) -> npt .NDArray [np .float64 ]:
265266 """
266267 Traverse the tree starting from the root node given an (un)observed point.
@@ -273,7 +274,7 @@ def _traverse_tree(
273274 Index of the node to start the traversal from
274275 split_variable : int
275276 Index of the variable used to split the node
276- excluded: Optional[List [int]]
277+ excluded: Optional[list [int]]
277278 Indexes of the variables to exclude when computing predictions
278279
279280 Returns
@@ -327,14 +328,14 @@ def _traverse_tree(
327328 return p_d
328329
329330 def _traverse_leaf_values (
330- self , leaf_values : List [npt .NDArray [np .float64 ]], leaf_n_values : List [int ], node_index : int
331+ self , leaf_values : list [npt .NDArray [np .float64 ]], leaf_n_values : list [int ], node_index : int
331332 ) -> None :
332333 """
333334 Traverse the tree appending leaf values starting from a particular node.
334335
335336 Parameters
336337 ----------
337- leaf_values : List [npt.NDArray[np.float64]]
338+ leaf_values : list [npt.NDArray[np.float64]]
338339 node_index : int
339340 """
340341 node = self .get_node (node_index )
0 commit comments