Skip to content

Commit b66c9e0

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 954ae88 commit b66c9e0

File tree

1 file changed

+58
-58
lines changed

1 file changed

+58
-58
lines changed

data_structures/binary_tree/avl_tree.py

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import random
44
from typing import Any
55

6+
67
class MyQueue:
78
__slots__ = ("data", "head", "tail")
8-
9+
910
def __init__(self) -> None:
1011
self.data: list[Any] = []
1112
self.head = self.tail = 0
@@ -22,80 +23,79 @@ def pop(self) -> Any:
2223
self.head += 1
2324
return ret
2425

26+
2527
class MyNode:
2628
__slots__ = ("data", "height", "left", "right")
27-
29+
2830
def __init__(self, data: Any) -> None:
2931
self.data = data
3032
self.height = 1
3133
self.left: MyNode | None = None
3234
self.right: MyNode | None = None
3335

34-
def get_height(node: MyNode | None) -> int:
36+
37+
def get_height(node: MyNode | None) -> int:
3538
return node.height if node else 0
3639

37-
def my_max(a: int, b: int) -> int:
40+
41+
def my_max(a: int, b: int) -> int:
3842
return a if a > b else b
3943

44+
4045
def right_rotation(node: MyNode) -> MyNode:
4146
left_child = node.left
4247
if left_child is None:
4348
return node
44-
49+
4550
node.left = left_child.right
4651
left_child.right = node
47-
52+
4853
# 拆分长表达式
49-
node_height = my_max(
50-
get_height(node.right),
51-
get_height(node.left)
52-
) + 1
54+
node_height = my_max(get_height(node.right), get_height(node.left)) + 1
5355
node.height = node_height
54-
55-
left_height = my_max(
56-
get_height(left_child.right),
57-
get_height(left_child.left)
58-
) + 1
56+
57+
left_height = my_max(get_height(left_child.right), get_height(left_child.left)) + 1
5958
left_child.height = left_height
60-
59+
6160
return left_child
61+
62+
6263
def left_rotation(node: MyNode) -> MyNode:
6364
right_child = node.right
6465
if right_child is None:
6566
return node
66-
67+
6768
node.right = right_child.left
6869
right_child.left = node
69-
70+
7071
# 拆分长表达式
71-
node_height = my_max(
72-
get_height(node.right),
73-
get_height(node.left)
74-
) + 1
72+
node_height = my_max(get_height(node.right), get_height(node.left)) + 1
7573
node.height = node_height
76-
77-
right_height = my_max(
78-
get_height(right_child.right),
79-
get_height(right_child.left)
80-
) + 1
74+
75+
right_height = (
76+
my_max(get_height(right_child.right), get_height(right_child.left)) + 1
77+
)
8178
right_child.height = right_height
82-
79+
8380
return right_child
8481

82+
8583
def lr_rotation(node: MyNode) -> MyNode:
8684
if node.left:
8785
node.left = left_rotation(node.left)
8886
return right_rotation(node)
8987

88+
9089
def rl_rotation(node: MyNode) -> MyNode:
9190
if node.right:
9291
node.right = right_rotation(node.right)
9392
return left_rotation(node)
9493

94+
9595
def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
9696
if node is None:
9797
return MyNode(data)
98-
98+
9999
if data < node.data:
100100
node.left = insert_node(node.left, data)
101101
if get_height(node.left) - get_height(node.right) == 2:
@@ -110,22 +110,21 @@ def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
110110
node = rl_rotation(node)
111111
else:
112112
node = left_rotation(node)
113-
114-
node.height = my_max(
115-
get_height(node.right),
116-
get_height(node.left)
117-
) + 1
113+
114+
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
118115
return node
119116

117+
120118
def get_left_most(root: MyNode) -> Any:
121119
while root.left:
122120
root = root.left
123121
return root.data
124122

123+
125124
def del_node(root: MyNode | None, data: Any) -> MyNode | None:
126125
if root is None:
127126
return None
128-
127+
129128
if data == root.data:
130129
if root.left and root.right:
131130
root.data = get_left_most(root.right)
@@ -136,14 +135,14 @@ def del_node(root: MyNode | None, data: Any) -> MyNode | None:
136135
root.left = del_node(root.left, data)
137136
else:
138137
root.right = del_node(root.right, data)
139-
138+
140139
if root.left is None and root.right is None:
141140
root.height = 1
142141
return root
143-
142+
144143
left_height = get_height(root.left)
145144
right_height = get_height(root.right)
146-
145+
147146
if right_height - left_height == 2:
148147
right_right = get_height(root.right.right) if root.right else 0
149148
right_left = get_height(root.right.left) if root.right else 0
@@ -152,38 +151,37 @@ def del_node(root: MyNode | None, data: Any) -> MyNode | None:
152151
left_left = get_height(root.left.left) if root.left else 0
153152
left_right = get_height(root.left.right) if root.left else 0
154153
root = right_rotation(root) if left_left > left_right else lr_rotation(root)
155-
156-
root.height = my_max(
157-
get_height(root.right),
158-
get_height(root.left)
159-
) + 1
154+
155+
root.height = my_max(get_height(root.right), get_height(root.left)) + 1
160156
return root
157+
158+
161159
class AVLTree:
162160
__slots__ = ("root",)
163-
164-
def __init__(self) -> None:
161+
162+
def __init__(self) -> None:
165163
self.root: MyNode | None = None
166-
167-
def get_height(self) -> int:
164+
165+
def get_height(self) -> int:
168166
return get_height(self.root)
169-
167+
170168
def insert(self, data: Any) -> None:
171169
self.root = insert_node(self.root, data)
172-
170+
173171
def delete(self, data: Any) -> None:
174172
self.root = del_node(self.root, data)
175-
173+
176174
def __str__(self) -> str:
177175
if self.root is None:
178176
return ""
179-
177+
180178
levels = []
181179
queue: list[MyNode | None] = [self.root]
182-
180+
183181
while queue:
184182
current = []
185183
next_level: list[MyNode | None] = []
186-
184+
187185
for node in queue:
188186
if node:
189187
current.append(str(node.data))
@@ -193,28 +191,30 @@ def __str__(self) -> str:
193191
current.append("*")
194192
next_level.append(None)
195193
next_level.append(None)
196-
194+
197195
if any(node is not None for node in next_level):
198196
levels.append(" ".join(current))
199197
queue = next_level
200198
else:
201199
if current:
202200
levels.append(" ".join(current))
203201
break
204-
205-
return "\n".join(levels) + "\n" + "*"*36
202+
203+
return "\n".join(levels) + "\n" + "*" * 36
204+
206205

207206
def test_avl_tree() -> None:
208207
t = AVLTree()
209208
lst = list(range(10))
210209
random.shuffle(lst)
211-
210+
212211
for i in lst:
213212
t.insert(i)
214-
213+
215214
random.shuffle(lst)
216215
for i in lst:
217216
t.delete(i)
218217

218+
219219
if __name__ == "__main__":
220220
test_avl_tree()

0 commit comments

Comments
 (0)