1+ r"""
2+ A binary search Tree
3+
4+ Example
5+ 8
6+ / \
7+ 3 10
8+ / \ \
9+ 1 6 14
10+ / \ /
11+ 4 7 13
12+
13+ >>> t = BinarySearchTree().insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
14+ >>> print(" ".join(repr(i.value) for i in t.traversal_tree()))
15+ 8 3 1 6 4 7 10 14 13
16+
17+ >>> tuple(i.value for i in t.traversal_tree(inorder))
18+ (1, 3, 4, 6, 7, 8, 10, 13, 14)
19+ >>> tuple(t)
20+ (1, 3, 4, 6, 7, 8, 10, 13, 14)
21+ >>> t.find_kth_smallest(3, t.root)
22+ 4
23+ >>> tuple(t)[3-1]
24+ 4
25+
26+ >>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder))
27+ 1 4 7 6 3 13 14 10 8
28+ >>> t.remove(20)
29+ Traceback (most recent call last):
30+ ...
31+ ValueError: Value 20 not found
32+ >>> BinarySearchTree().search(6)
33+ Traceback (most recent call last):
34+ ...
35+ IndexError: Warning: Tree is empty! please use another.
36+
37+ Other example:
38+
39+ >>> testlist = (8, 3, 6, 1, 10, 14, 13, 4, 7)
40+ >>> t = BinarySearchTree()
41+ >>> for i in testlist:
42+ ... t.insert(i) # doctest: +ELLIPSIS
43+ BinarySearchTree(root=8)
44+ BinarySearchTree(root={'8': (3, None)})
45+ BinarySearchTree(root={'8': ({'3': (None, 6)}, None)})
46+ BinarySearchTree(root={'8': ({'3': (1, 6)}, None)})
47+ BinarySearchTree(root={'8': ({'3': (1, 6)}, 10)})
48+ BinarySearchTree(root={'8': ({'3': (1, 6)}, {'10': (None, 14)})})
49+ BinarySearchTree(root={'8': ({'3': (1, 6)}, {'10': (None, {'14': (13, None)})})})
50+ BinarySearchTree(root={'8': ({'3': (1, {'6': (4, None)})}, {'10': (None, {'14': ...
51+ BinarySearchTree(root={'8': ({'3': (1, {'6': (4, 7)})}, {'10': (None, {'14': (13, ...
52+
53+ Prints all the elements of the list in order traversal
54+ >>> print(t)
55+ {'8': ({'3': (1, {'6': (4, 7)})}, {'10': (None, {'14': (13, None)})})}
56+
57+ Test existence
58+ >>> t.search(6) is not None
59+ True
60+ >>> 6 in t
61+ True
62+ >>> t.search(-1) is not None
63+ False
64+ >>> -1 in t
65+ False
66+
67+ >>> t.search(6).is_right
68+ True
69+ >>> t.search(1).is_right
70+ False
71+
72+ >>> t.get_max().value
73+ 14
74+ >>> max(t)
75+ 14
76+ >>> t.get_min().value
77+ 1
78+ >>> min(t)
79+ 1
80+ >>> t.empty()
81+ False
82+ >>> not t
83+ False
84+ >>> for i in testlist:
85+ ... t.remove(i)
86+ >>> t.empty()
87+ True
88+ >>> not t
89+ True
90+ """
191from __future__ import annotations
2- from pprint import pformat
3- from collections .abc import Iterable , Iterator
92+
493from dataclasses import dataclass
5- from typing import Any , Self
94+ from pprint import pformat
95+ from typing import Iterator
696
797
898@dataclass
999class Node :
10100 value : int
11101 left : Node | None = None
12102 right : Node | None = None
13- parent : Node | None = None
103+ parent : Node | None = None # For easier deletion
104+
105+ @property
106+ def is_right (self ) -> bool :
107+ return bool (self .parent and self is self .parent .right )
108+
109+ def __iter__ (self ) -> Iterator [int ]:
110+ if self .left :
111+ yield from self .left
112+ yield self .value
113+ if self .right :
114+ yield from self .right
14115
15116 def __repr__ (self ) -> str :
16117 if self .left is None and self .right is None :
@@ -22,9 +123,21 @@ def __repr__(self) -> str:
22123class BinarySearchTree :
23124 root : Node | None = None
24125
126+ def __bool__ (self ) -> bool :
127+ return self .root is not None
128+
129+ def __iter__ (self ) -> Iterator [int ]:
130+ if self .root :
131+ yield from self .root
132+ return iter (())
133+
134+ def __str__ (self ) -> str :
135+ return str (self .root ) if self .root else "Empty tree"
136+
25137 def __reassign_nodes (self , node : Node , new_children : Node | None ) -> None :
26138 if new_children is not None :
27139 new_children .parent = node .parent
140+
28141 if node .parent is not None :
29142 if node .is_right :
30143 node .parent .right = new_children
@@ -33,38 +146,71 @@ def __reassign_nodes(self, node: Node, new_children: Node | None) -> None:
33146 else :
34147 self .root = new_children
35148
36- def __insert (self , value ) -> None :
149+ def empty (self ) -> bool :
150+ return self .root is None
151+
152+ def __insert (self , value : int ) -> None :
37153 new_node = Node (value )
38154 if self .empty ():
39155 self .root = new_node
40- else :
41- parent_node = self .root
42- while True :
43- if value < parent_node .value :
44- if parent_node .left is None :
45- parent_node .left = new_node
46- break
47- else :
48- parent_node = parent_node .left
49- elif parent_node .right is None :
156+ return
157+
158+ parent_node = self .root
159+ while parent_node :
160+ if value < parent_node .value :
161+ if parent_node .left is None :
162+ parent_node .left = new_node
163+ new_node .parent = parent_node
164+ return
165+ parent_node = parent_node .left
166+ else :
167+ if parent_node .right is None :
50168 parent_node .right = new_node
51- break
52- else :
53- parent_node = parent_node .right
54- new_node .parent = parent_node
169+ new_node .parent = parent_node
170+ return
171+ parent_node = parent_node .right
55172
56- def search (self , value ) -> Node | None :
173+ def insert (self , * values : int ) -> BinarySearchTree :
174+ for value in values :
175+ self .__insert (value )
176+ return self
177+
178+ def search (self , value : int ) -> Node | None :
57179 if self .empty ():
58180 raise IndexError ("Warning: Tree is empty! please use another." )
181+
59182 node = self .root
60183 while node is not None and node .value != value :
61- node = node .left if value < node .value else node .right
184+ if value < node .value :
185+ node = node .left
186+ else :
187+ node = node .right
188+ return node
189+
190+ def get_max (self , node : Node | None = None ) -> Node | None :
191+ if node is None :
192+ node = self .root
193+ if node is None :
194+ return None
195+
196+ while node .right is not None :
197+ node = node .right
198+ return node
199+ def get_min (self , node : Node | None = None ) -> Node | None :
200+ if node is None :
201+ node = self .root
202+ if node is None :
203+ return None
204+
205+ while node .left is not None :
206+ node = node .left
62207 return node
63208
64209 def remove (self , value : int ) -> None :
65210 node = self .search (value )
66211 if node is None :
67- raise ValueError (f"Value { value } not found" )
212+ error_msg = f"Value { value } not found"
213+ raise ValueError (error_msg )
68214
69215 if node .left is None and node .right is None :
70216 self .__reassign_nodes (node , None )
@@ -74,23 +220,45 @@ def remove(self, value: int) -> None:
74220 self .__reassign_nodes (node , node .left )
75221 else :
76222 predecessor = self .get_max (node .left )
77- self .remove (predecessor .value )
78- node .value = predecessor .value
223+ if predecessor is not None :
224+ self .remove (predecessor .value )
225+ node .value = predecessor .value
226+
227+ def preorder_traverse (self , node : Node | None ) -> Iterator [Node ]:
228+ if node is not None :
229+ yield node
230+ yield from self .preorder_traverse (node .left )
231+ yield from self .preorder_traverse (node .right )
232+
233+ def traversal_tree (self , traversal_function = None ) -> Iterator [Node ]:
234+ if traversal_function is None :
235+ return self .preorder_traverse (self .root )
236+ return traversal_function (self .root )
237+
238+ def inorder (self , arr : list [int ], node : Node | None ) -> None :
239+ if node :
240+ self .inorder (arr , node .left )
241+ arr .append (node .value )
242+ self .inorder (arr , node .right )
243+
244+ def find_kth_smallest (self , k : int , node : Node ) -> int :
245+ arr : list [int ] = []
246+ self .inorder (arr , node )
247+ return arr [k - 1 ]
79248
80249
81- # 修复的递归函数
82250def inorder (curr_node : Node | None ) -> list [Node ]:
83251 """Inorder traversal (left, self, right)"""
84252 if curr_node is None :
85253 return []
86- return inorder (curr_node .left ) + [ curr_node ] + inorder (curr_node .right )
254+ return [ * inorder (curr_node .left ), curr_node , * inorder (curr_node .right )]
87255
88256
89257def postorder (curr_node : Node | None ) -> list [Node ]:
90258 """Postorder traversal (left, right, self)"""
91259 if curr_node is None :
92260 return []
93- return postorder (curr_node .left ) + postorder (curr_node .right ) + [ curr_node ]
261+ return [ * postorder (curr_node .left ), * postorder (curr_node .right ), curr_node ]
94262
95263
96264if __name__ == "__main__" :
0 commit comments