Skip to content

Commit e1d4177

Browse files
committed
Make classes generic
1 parent a37858d commit e1d4177

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

source-code/typing/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ Type checking can be done using [mypy](http://mypy-lang.org/index.html).
4242
1. `classes.py`: illustration of using type hints with a user-defined class.
4343
1. `classes_incorrect.py`: illustration of using type hints with a user-defined
4444
class with errors.
45-
1. `tree.py`: illustration of using type hints on more sophisticated classes.
45+
1. `tree.py`: illustration of using type hints on more sophisticated classes, as well as generic types.

source-code/typing/tree.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
from typing import Any, Callable, Self
55

66

7-
type TransformFunc = Callable[[Node], None]
8-
type VisitFunc = Callable[[Node], Any]
7+
type TransformFunc[T] = Callable[[Node[T]], None]
8+
type VisitFunc[T] = Callable[[Node[T]], Any]
99
type AggrFunc = Callable[[Any, Any, Any], Any]
1010

1111

12-
class Node:
12+
class Node[T]:
1313
_left: Self | None
1414
_right: Self | None
15-
_data: int
15+
_data: T
1616

17-
def __init__(self, data: int):
17+
def __init__(self, data: T):
1818
self._left = None
1919
self._right = None
2020
self._data = data
@@ -36,11 +36,11 @@ def right(self, right: Self) -> None:
3636
self._right = right
3737

3838
@property
39-
def data(self) -> int:
39+
def data(self) -> T:
4040
return self._data
4141

4242
@data.setter
43-
def data(self, data: int) -> None:
43+
def data(self, data: T) -> None:
4444
self._data = data
4545

4646
@property
@@ -64,7 +64,7 @@ def nr_descendants(self) -> int:
6464
count += 1 + self._right.nr_descendants
6565
return count
6666

67-
def transformn(self, func: TransformFunc) -> None:
67+
def transformn(self, func: TransformFunc[T]) -> None:
6868
func(self)
6969
if self._left is not None:
7070
self._left.transformn(func)
@@ -77,10 +77,12 @@ def __str__(self) -> str:
7777
def __repr__(self) -> str:
7878
return f"{self.data}"
7979

80-
def visit(self, visit_func: VisitFunc, aggr_func: AggrFunc) -> Any:
80+
def visit(self, visit_func: VisitFunc[T], aggr_func: AggrFunc) -> Any:
8181
self_value = visit_func(self)
8282
left_value = (
83-
self._left.visit(visit_func, aggr_func) if self._left is not None else None
83+
self._left.visit(visit_func, aggr_func)
84+
if self._left is not None
85+
else None
8486
)
8587
right_value = (
8688
self._right.visit(visit_func, aggr_func)
@@ -90,7 +92,7 @@ def visit(self, visit_func: VisitFunc, aggr_func: AggrFunc) -> Any:
9092
return aggr_func(self_value, left_value, right_value)
9193

9294

93-
def str_visit(node: Node) -> str:
95+
def str_visit[T](node: Node[T]) -> str:
9496
return str(node.data)
9597

9698

@@ -104,34 +106,38 @@ def str_aggr(self_value: str, left_value: str, right_value: str) -> str:
104106
return aggr
105107

106108

107-
def double_value(node: Node) -> None:
109+
def double_value(node: Node[int]) -> None:
108110
node.data = 2 * node.data
109111

110112

111-
class Tree:
112-
_root: Node | None
113+
class Tree[T]:
114+
_root: Node[T] | None
113115

114-
def __init__(self, root: Node | None = None):
116+
def __init__(self, root: Node[T] | None = None):
115117
self._root = root
116118

117119
@property
118-
def root(self) -> Node | None:
120+
def root(self) -> Node[T] | None:
119121
return self._root
120122

121123
@root.setter
122-
def root(self, root: Node) -> None:
124+
def root(self, root: Node[T]) -> None:
123125
self._root = root
124126

125127
@property
126128
def nr_of_nodes(self) -> int:
127129
return 0 if self._root is None else 1 + self._root.nr_descendants
128130

129-
def transformn(self, func: TransformFunc) -> None:
131+
def transformn(self, func: TransformFunc[T]) -> None:
130132
if self._root is not None:
131133
self._root.transformn(func)
132134

133135
def __str__(self) -> str:
134-
return "" if self._root is None else self._root.visit(str_visit, str_aggr)
136+
return (
137+
""
138+
if self._root is None
139+
else self._root.visit(str_visit, str_aggr)
140+
)
135141

136142
def __repr__(self) -> str:
137143
return f"{self._root}"

0 commit comments

Comments
 (0)