Skip to content

Commit 8fb7946

Browse files
authored
Update avl_tree.py
1 parent 2708486 commit 8fb7946

File tree

1 file changed

+97
-45
lines changed

1 file changed

+97
-45
lines changed

data_structures/binary_tree/avl_tree.py

Lines changed: 97 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""
2+
Auto-balanced binary tree implementation
3+
For doctests: python3 -m doctest -v avl_tree.py
4+
For testing: python avl_tree.py
5+
"""
6+
17
from __future__ import annotations
28

39
import math
@@ -28,19 +34,34 @@ def __init__(self, data: Any) -> None:
2834
self.left = self.right = None
2935
self.height = 1
3036

31-
def get_data(self) -> Any: return self.data
32-
def get_left(self) -> MyNode | None: return self.left
33-
def get_right(self) -> MyNode | None: return self.right
34-
def get_height(self) -> int: return self.height
35-
def set_data(self, data: Any) -> None: self.data = data
36-
def set_left(self, node: MyNode | None) -> None: self.left = node
37-
def set_right(self, node: MyNode | None) -> None: self.right = node
38-
def set_height(self, height: int) -> None: self.height = height
39-
40-
def get_height(node: MyNode | None) -> int:
37+
def get_data(self) -> Any:
38+
return self.data
39+
40+
def get_left(self) -> MyNode | None:
41+
return self.left
42+
43+
def get_right(self) -> MyNode | None:
44+
return self.right
45+
46+
def get_height(self) -> int:
47+
return self.height
48+
49+
def set_data(self, data: Any) -> None:
50+
self.data = data
51+
52+
def set_left(self, node: MyNode | None) -> None:
53+
self.left = node
54+
55+
def set_right(self, node: MyNode | None) -> None:
56+
self.right = node
57+
58+
def set_height(self, height: int) -> None:
59+
self.height = height
60+
61+
def get_height(node: MyNode | None) -> int:
4162
return node.height if node else 0
4263

43-
def my_max(a: int, b: int) -> int:
64+
def my_max(a: int, b: int) -> int:
4465
return a if a > b else b
4566

4667
def right_rotation(node: MyNode) -> MyNode:
@@ -51,7 +72,8 @@ def right_rotation(node: MyNode) -> MyNode:
5172
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
5273
ret.height = my_max(get_height(ret.right), get_height(ret.left)) + 1
5374
return ret
54-
def left_rotation(node: MyNode) -> MyNode:
75+
76+
def left_rotation(node: MyNode) -> MyNode:
5577
print("right rotation node:", node.data)
5678
ret = node.right
5779
node.right = ret.left
@@ -69,17 +91,24 @@ def rl_rotation(node: MyNode) -> MyNode:
6991
return left_rotation(node)
7092

7193
def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
72-
if not node: return MyNode(data)
73-
94+
if not node:
95+
return MyNode(data)
96+
7497
if data < node.data:
7598
node.left = insert_node(node.left, data)
7699
if get_height(node.left) - get_height(node.right) == 2:
77-
node = right_rotation(node) if data < node.left.data else lr_rotation(node)
100+
if data < node.left.data:
101+
node = right_rotation(node)
102+
else:
103+
node = lr_rotation(node)
78104
else:
79105
node.right = insert_node(node.right, data)
80106
if get_height(node.right) - get_height(node.left) == 2:
81-
node = rl_rotation(node) if data < node.right.data else left_rotation(node)
82-
107+
if data < node.right.data:
108+
node = rl_rotation(node)
109+
else:
110+
node = left_rotation(node)
111+
83112
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
84113
return node
85114

@@ -93,68 +122,91 @@ def del_node(root: MyNode, data: Any) -> MyNode | None:
93122
if root.left and root.right:
94123
root.data = get_extreme(root.right, False)
95124
root.right = del_node(root.right, root.data)
96-
else: return root.left or root.right
125+
else:
126+
return root.left or root.right
97127
elif root.data > data:
98-
if not root.left: return root
128+
if not root.left:
129+
return root
99130
root.left = del_node(root.left, data)
100-
else: root.right = del_node(root.right, data)
101-
102-
if get_height(root.right) - get_height(root.left) == 2:
103-
root = left_rotation(root) if get_height(root.right.right) > get_height(root.right.left) else rl_rotation(root)
104-
elif get_height(root.right) - get_height(root.left) == -2:
105-
root = right_rotation(root) if get_height(root.left.left) > get_height(root.left.right) else lr_rotation(root)
106-
131+
else:
132+
root.right = del_node(root.right, data)
133+
134+
# Handle balancing
135+
right_height = get_height(root.right)
136+
left_height = get_height(root.left)
137+
138+
if right_height - left_height == 2:
139+
if get_height(root.right.right) > get_height(root.right.left):
140+
root = left_rotation(root)
141+
else:
142+
root = rl_rotation(root)
143+
elif right_height - left_height == -2:
144+
if get_height(root.left.left) > get_height(root.left.right):
145+
root = right_rotation(root)
146+
else:
147+
root = lr_rotation(root)
148+
107149
root.height = my_max(get_height(root.right), get_height(root.left)) + 1
108150
return root
109-
class AVLtree:
110-
def __init__(self) -> None: self.root = None
111-
def get_height(self) -> int: return get_height(self.root)
112-
151+
class AVLtree:
152+
def __init__(self) -> None:
153+
self.root = None
154+
155+
def get_height(self) -> int:
156+
return get_height(self.root)
157+
113158
def insert(self, data: Any) -> None:
114159
print(f"insert:{data}")
115160
self.root = insert_node(self.root, data)
116-
161+
117162
def del_node(self, data: Any) -> None:
118163
print(f"delete:{data}")
119-
if not self.root: return
164+
if not self.root:
165+
return
120166
self.root = del_node(self.root, data)
121-
167+
122168
def __str__(self) -> str:
123-
if not self.root: return ""
169+
if not self.root:
170+
return ""
124171
q, output, layer, cnt = MyQueue(), "", self.get_height(), 0
125172
q.push(self.root)
126-
173+
127174
while not q.is_empty():
128175
node = q.pop()
129176
space = " " * int(2**(layer-1))
130177
output += space + (str(node.data) if node else "*") + space
131178
cnt += 1
132-
179+
133180
if node:
134181
q.push(node.left)
135182
q.push(node.right)
136183
else:
137184
q.push(None)
138185
q.push(None)
139-
140-
if any(cnt == 2**i - 1 for i in range(10)):
141-
layer -= 1
142-
output += "\n"
143-
if layer == 0: break
144-
186+
187+
for i in range(10):
188+
if cnt == 2**i - 1:
189+
layer -= 1
190+
output += "\n"
191+
if layer == 0:
192+
break
193+
break
194+
145195
return output + "\n" + "*"*36
146-
def _test() -> None: doctest.testmod()
196+
197+
def _test() -> None:
198+
doctest.testmod()
147199

148200
if __name__ == "__main__":
149201
_test()
150202
t = AVLtree()
151203
lst = list(range(10))
152204
random.shuffle(lst)
153-
205+
154206
for i in lst:
155207
t.insert(i)
156208
print(t)
157-
209+
158210
random.shuffle(lst)
159211
for i in lst:
160212
t.del_node(i)

0 commit comments

Comments
 (0)