4
4
from typing import Any , Callable , Self
5
5
6
6
7
- type TransformFunc = Callable [[Node ], None ]
8
- type VisitFunc = Callable [[Node ], Any ]
7
+ type TransformFunc [ T ] = Callable [[Node [ T ] ], None ]
8
+ type VisitFunc [ T ] = Callable [[Node [ T ] ], Any ]
9
9
type AggrFunc = Callable [[Any , Any , Any ], Any ]
10
10
11
11
12
- class Node :
12
+ class Node [ T ] :
13
13
_left : Self | None
14
14
_right : Self | None
15
- _data : int
15
+ _data : T
16
16
17
- def __init__ (self , data : int ):
17
+ def __init__ (self , data : T ):
18
18
self ._left = None
19
19
self ._right = None
20
20
self ._data = data
@@ -36,11 +36,11 @@ def right(self, right: Self) -> None:
36
36
self ._right = right
37
37
38
38
@property
39
- def data (self ) -> int :
39
+ def data (self ) -> T :
40
40
return self ._data
41
41
42
42
@data .setter
43
- def data (self , data : int ) -> None :
43
+ def data (self , data : T ) -> None :
44
44
self ._data = data
45
45
46
46
@property
@@ -64,7 +64,7 @@ def nr_descendants(self) -> int:
64
64
count += 1 + self ._right .nr_descendants
65
65
return count
66
66
67
- def transformn (self , func : TransformFunc ) -> None :
67
+ def transformn (self , func : TransformFunc [ T ] ) -> None :
68
68
func (self )
69
69
if self ._left is not None :
70
70
self ._left .transformn (func )
@@ -77,10 +77,12 @@ def __str__(self) -> str:
77
77
def __repr__ (self ) -> str :
78
78
return f"{ self .data } "
79
79
80
- def visit (self , visit_func : VisitFunc , aggr_func : AggrFunc ) -> Any :
80
+ def visit (self , visit_func : VisitFunc [ T ] , aggr_func : AggrFunc ) -> Any :
81
81
self_value = visit_func (self )
82
82
left_value = (
83
- self ._left .visit (visit_func , aggr_func ) if self ._left is not None else None
83
+ self ._left .visit (visit_func , aggr_func )
84
+ if self ._left is not None
85
+ else None
84
86
)
85
87
right_value = (
86
88
self ._right .visit (visit_func , aggr_func )
@@ -90,7 +92,7 @@ def visit(self, visit_func: VisitFunc, aggr_func: AggrFunc) -> Any:
90
92
return aggr_func (self_value , left_value , right_value )
91
93
92
94
93
- def str_visit (node : Node ) -> str :
95
+ def str_visit [ T ] (node : Node [ T ] ) -> str :
94
96
return str (node .data )
95
97
96
98
@@ -104,34 +106,38 @@ def str_aggr(self_value: str, left_value: str, right_value: str) -> str:
104
106
return aggr
105
107
106
108
107
- def double_value (node : Node ) -> None :
109
+ def double_value (node : Node [ int ] ) -> None :
108
110
node .data = 2 * node .data
109
111
110
112
111
- class Tree :
112
- _root : Node | None
113
+ class Tree [ T ] :
114
+ _root : Node [ T ] | None
113
115
114
- def __init__ (self , root : Node | None = None ):
116
+ def __init__ (self , root : Node [ T ] | None = None ):
115
117
self ._root = root
116
118
117
119
@property
118
- def root (self ) -> Node | None :
120
+ def root (self ) -> Node [ T ] | None :
119
121
return self ._root
120
122
121
123
@root .setter
122
- def root (self , root : Node ) -> None :
124
+ def root (self , root : Node [ T ] ) -> None :
123
125
self ._root = root
124
126
125
127
@property
126
128
def nr_of_nodes (self ) -> int :
127
129
return 0 if self ._root is None else 1 + self ._root .nr_descendants
128
130
129
- def transformn (self , func : TransformFunc ) -> None :
131
+ def transformn (self , func : TransformFunc [ T ] ) -> None :
130
132
if self ._root is not None :
131
133
self ._root .transformn (func )
132
134
133
135
def __str__ (self ) -> str :
134
- return "" if self ._root is None else self ._root .visit (str_visit , str_aggr )
136
+ return (
137
+ ""
138
+ if self ._root is None
139
+ else self ._root .visit (str_visit , str_aggr )
140
+ )
135
141
136
142
def __repr__ (self ) -> str :
137
143
return f"{ self ._root } "
0 commit comments