Skip to content

Commit 1490748

Browse files
authored
Update binary_search_tree.py
1 parent 57c44ac commit 1490748

File tree

1 file changed

+195
-27
lines changed

1 file changed

+195
-27
lines changed

data_structures/binary_tree/binary_search_tree.py

Lines changed: 195 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,117 @@
1+
r"""
2+
A binary search Tree
3+
4+
Example
5+
8
6+
/ \
7+
3 10
8+
/ \ \
9+
1 6 14
10+
/ \ /
11+
4 7 13
12+
13+
>>> t = BinarySearchTree().insert(8, 3, 6, 1, 10, 14, 13, 4, 7)
14+
>>> print(" ".join(repr(i.value) for i in t.traversal_tree()))
15+
8 3 1 6 4 7 10 14 13
16+
17+
>>> tuple(i.value for i in t.traversal_tree(inorder))
18+
(1, 3, 4, 6, 7, 8, 10, 13, 14)
19+
>>> tuple(t)
20+
(1, 3, 4, 6, 7, 8, 10, 13, 14)
21+
>>> t.find_kth_smallest(3, t.root)
22+
4
23+
>>> tuple(t)[3-1]
24+
4
25+
26+
>>> print(" ".join(repr(i.value) for i in t.traversal_tree(postorder))
27+
1 4 7 6 3 13 14 10 8
28+
>>> t.remove(20)
29+
Traceback (most recent call last):
30+
...
31+
ValueError: Value 20 not found
32+
>>> BinarySearchTree().search(6)
33+
Traceback (most recent call last):
34+
...
35+
IndexError: Warning: Tree is empty! please use another.
36+
37+
Other example:
38+
39+
>>> testlist = (8, 3, 6, 1, 10, 14, 13, 4, 7)
40+
>>> t = BinarySearchTree()
41+
>>> for i in testlist:
42+
... t.insert(i) # doctest: +ELLIPSIS
43+
BinarySearchTree(root=8)
44+
BinarySearchTree(root={'8': (3, None)})
45+
BinarySearchTree(root={'8': ({'3': (None, 6)}, None)})
46+
BinarySearchTree(root={'8': ({'3': (1, 6)}, None)})
47+
BinarySearchTree(root={'8': ({'3': (1, 6)}, 10)})
48+
BinarySearchTree(root={'8': ({'3': (1, 6)}, {'10': (None, 14)})})
49+
BinarySearchTree(root={'8': ({'3': (1, 6)}, {'10': (None, {'14': (13, None)})})})
50+
BinarySearchTree(root={'8': ({'3': (1, {'6': (4, None)})}, {'10': (None, {'14': ...
51+
BinarySearchTree(root={'8': ({'3': (1, {'6': (4, 7)})}, {'10': (None, {'14': (13, ...
52+
53+
Prints all the elements of the list in order traversal
54+
>>> print(t)
55+
{'8': ({'3': (1, {'6': (4, 7)})}, {'10': (None, {'14': (13, None)})})}
56+
57+
Test existence
58+
>>> t.search(6) is not None
59+
True
60+
>>> 6 in t
61+
True
62+
>>> t.search(-1) is not None
63+
False
64+
>>> -1 in t
65+
False
66+
67+
>>> t.search(6).is_right
68+
True
69+
>>> t.search(1).is_right
70+
False
71+
72+
>>> t.get_max().value
73+
14
74+
>>> max(t)
75+
14
76+
>>> t.get_min().value
77+
1
78+
>>> min(t)
79+
1
80+
>>> t.empty()
81+
False
82+
>>> not t
83+
False
84+
>>> for i in testlist:
85+
... t.remove(i)
86+
>>> t.empty()
87+
True
88+
>>> not t
89+
True
90+
"""
191
from __future__ import annotations
2-
from pprint import pformat
3-
from collections.abc import Iterable, Iterator
92+
493
from dataclasses import dataclass
5-
from typing import Any, Self
94+
from pprint import pformat
95+
from typing import Iterator
696

797

898
@dataclass
999
class Node:
10100
value: int
11101
left: Node | None = None
12102
right: Node | None = None
13-
parent: Node | None = None
103+
parent: Node | None = None # For easier deletion
104+
105+
@property
106+
def is_right(self) -> bool:
107+
return bool(self.parent and self is self.parent.right)
108+
109+
def __iter__(self) -> Iterator[int]:
110+
if self.left:
111+
yield from self.left
112+
yield self.value
113+
if self.right:
114+
yield from self.right
14115

15116
def __repr__(self) -> str:
16117
if self.left is None and self.right is None:
@@ -22,9 +123,21 @@ def __repr__(self) -> str:
22123
class BinarySearchTree:
23124
root: Node | None = None
24125

126+
def __bool__(self) -> bool:
127+
return self.root is not None
128+
129+
def __iter__(self) -> Iterator[int]:
130+
if self.root:
131+
yield from self.root
132+
return iter(())
133+
134+
def __str__(self) -> str:
135+
return str(self.root) if self.root else "Empty tree"
136+
25137
def __reassign_nodes(self, node: Node, new_children: Node | None) -> None:
26138
if new_children is not None:
27139
new_children.parent = node.parent
140+
28141
if node.parent is not None:
29142
if node.is_right:
30143
node.parent.right = new_children
@@ -33,38 +146,71 @@ def __reassign_nodes(self, node: Node, new_children: Node | None) -> None:
33146
else:
34147
self.root = new_children
35148

36-
def __insert(self, value) -> None:
149+
def empty(self) -> bool:
150+
return self.root is None
151+
152+
def __insert(self, value: int) -> None:
37153
new_node = Node(value)
38154
if self.empty():
39155
self.root = new_node
40-
else:
41-
parent_node = self.root
42-
while True:
43-
if value < parent_node.value:
44-
if parent_node.left is None:
45-
parent_node.left = new_node
46-
break
47-
else:
48-
parent_node = parent_node.left
49-
elif parent_node.right is None:
156+
return
157+
158+
parent_node = self.root
159+
while parent_node:
160+
if value < parent_node.value:
161+
if parent_node.left is None:
162+
parent_node.left = new_node
163+
new_node.parent = parent_node
164+
return
165+
parent_node = parent_node.left
166+
else:
167+
if parent_node.right is None:
50168
parent_node.right = new_node
51-
break
52-
else:
53-
parent_node = parent_node.right
54-
new_node.parent = parent_node
169+
new_node.parent = parent_node
170+
return
171+
parent_node = parent_node.right
55172

56-
def search(self, value) -> Node | None:
173+
def insert(self, *values: int) -> BinarySearchTree:
174+
for value in values:
175+
self.__insert(value)
176+
return self
177+
178+
def search(self, value: int) -> Node | None:
57179
if self.empty():
58180
raise IndexError("Warning: Tree is empty! please use another.")
181+
59182
node = self.root
60183
while node is not None and node.value != value:
61-
node = node.left if value < node.value else node.right
184+
if value < node.value:
185+
node = node.left
186+
else:
187+
node = node.right
188+
return node
189+
190+
def get_max(self, node: Node | None = None) -> Node | None:
191+
if node is None:
192+
node = self.root
193+
if node is None:
194+
return None
195+
196+
while node.right is not None:
197+
node = node.right
198+
return node
199+
def get_min(self, node: Node | None = None) -> Node | None:
200+
if node is None:
201+
node = self.root
202+
if node is None:
203+
return None
204+
205+
while node.left is not None:
206+
node = node.left
62207
return node
63208

64209
def remove(self, value: int) -> None:
65210
node = self.search(value)
66211
if node is None:
67-
raise ValueError(f"Value {value} not found")
212+
error_msg = f"Value {value} not found"
213+
raise ValueError(error_msg)
68214

69215
if node.left is None and node.right is None:
70216
self.__reassign_nodes(node, None)
@@ -74,23 +220,45 @@ def remove(self, value: int) -> None:
74220
self.__reassign_nodes(node, node.left)
75221
else:
76222
predecessor = self.get_max(node.left)
77-
self.remove(predecessor.value)
78-
node.value = predecessor.value
223+
if predecessor is not None:
224+
self.remove(predecessor.value)
225+
node.value = predecessor.value
226+
227+
def preorder_traverse(self, node: Node | None) -> Iterator[Node]:
228+
if node is not None:
229+
yield node
230+
yield from self.preorder_traverse(node.left)
231+
yield from self.preorder_traverse(node.right)
232+
233+
def traversal_tree(self, traversal_function=None) -> Iterator[Node]:
234+
if traversal_function is None:
235+
return self.preorder_traverse(self.root)
236+
return traversal_function(self.root)
237+
238+
def inorder(self, arr: list[int], node: Node | None) -> None:
239+
if node:
240+
self.inorder(arr, node.left)
241+
arr.append(node.value)
242+
self.inorder(arr, node.right)
243+
244+
def find_kth_smallest(self, k: int, node: Node) -> int:
245+
arr: list[int] = []
246+
self.inorder(arr, node)
247+
return arr[k - 1]
79248

80249

81-
# 修复的递归函数
82250
def inorder(curr_node: Node | None) -> list[Node]:
83251
"""Inorder traversal (left, self, right)"""
84252
if curr_node is None:
85253
return []
86-
return inorder(curr_node.left) + [curr_node] + inorder(curr_node.right)
254+
return [*inorder(curr_node.left), curr_node, *inorder(curr_node.right)]
87255

88256

89257
def postorder(curr_node: Node | None) -> list[Node]:
90258
"""Postorder traversal (left, right, self)"""
91259
if curr_node is None:
92260
return []
93-
return postorder(curr_node.left) + postorder(curr_node.right) + [curr_node]
261+
return [*postorder(curr_node.left), *postorder(curr_node.right), curr_node]
94262

95263

96264
if __name__ == "__main__":

0 commit comments

Comments
 (0)