Skip to content

Commit 954ae88

Browse files
authored
Update avl_tree.py
1 parent 4930f1b commit 954ae88

File tree

1 file changed

+68
-41
lines changed

1 file changed

+68
-41
lines changed

data_structures/binary_tree/avl_tree.py

Lines changed: 68 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class MyQueue:
77
__slots__ = ("data", "head", "tail")
8-
8+
99
def __init__(self) -> None:
1010
self.data: list[Any] = []
1111
self.head = self.tail = 0
@@ -23,40 +23,63 @@ def pop(self) -> Any:
2323
return ret
2424

2525
class MyNode:
26-
__slots__ = ("data", "height", "left", "right") # 按字母顺序排序
27-
26+
__slots__ = ("data", "height", "left", "right")
27+
2828
def __init__(self, data: Any) -> None:
2929
self.data = data
3030
self.height = 1
3131
self.left: MyNode | None = None
3232
self.right: MyNode | None = None
3333

34-
def get_height(node: MyNode | None) -> int:
34+
def get_height(node: MyNode | None) -> int:
3535
return node.height if node else 0
3636

37-
def my_max(a: int, b: int) -> int:
37+
def my_max(a: int, b: int) -> int:
3838
return a if a > b else b
3939

4040
def right_rotation(node: MyNode) -> MyNode:
4141
left_child = node.left
4242
if left_child is None:
4343
return node
44-
44+
4545
node.left = left_child.right
4646
left_child.right = node
47-
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
48-
left_child.height = my_max(get_height(left_child.right), get_height(left_child.left)) + 1
47+
48+
# 拆分长表达式
49+
node_height = my_max(
50+
get_height(node.right),
51+
get_height(node.left)
52+
) + 1
53+
node.height = node_height
54+
55+
left_height = my_max(
56+
get_height(left_child.right),
57+
get_height(left_child.left)
58+
) + 1
59+
left_child.height = left_height
60+
4961
return left_child
50-
5162
def left_rotation(node: MyNode) -> MyNode:
5263
right_child = node.right
5364
if right_child is None:
5465
return node
55-
66+
5667
node.right = right_child.left
5768
right_child.left = node
58-
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
59-
right_child.height = my_max(get_height(right_child.right), get_height(right_child.left)) + 1
69+
70+
# 拆分长表达式
71+
node_height = my_max(
72+
get_height(node.right),
73+
get_height(node.left)
74+
) + 1
75+
node.height = node_height
76+
77+
right_height = my_max(
78+
get_height(right_child.right),
79+
get_height(right_child.left)
80+
) + 1
81+
right_child.height = right_height
82+
6083
return right_child
6184

6285
def lr_rotation(node: MyNode) -> MyNode:
@@ -72,7 +95,7 @@ def rl_rotation(node: MyNode) -> MyNode:
7295
def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
7396
if node is None:
7497
return MyNode(data)
75-
98+
7699
if data < node.data:
77100
node.left = insert_node(node.left, data)
78101
if get_height(node.left) - get_height(node.right) == 2:
@@ -87,8 +110,11 @@ def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
87110
node = rl_rotation(node)
88111
else:
89112
node = left_rotation(node)
90-
91-
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
113+
114+
node.height = my_max(
115+
get_height(node.right),
116+
get_height(node.left)
117+
) + 1
92118
return node
93119

94120
def get_left_most(root: MyNode) -> Any:
@@ -99,7 +125,8 @@ def get_left_most(root: MyNode) -> Any:
99125
def del_node(root: MyNode | None, data: Any) -> MyNode | None:
100126
if root is None:
101127
return None
102-
if data == root.data:
128+
129+
if data == root.data:
103130
if root.left and root.right:
104131
root.data = get_left_most(root.right)
105132
root.right = del_node(root.right, root.data)
@@ -109,82 +136,82 @@ def del_node(root: MyNode | None, data: Any) -> MyNode | None:
109136
root.left = del_node(root.left, data)
110137
else:
111138
root.right = del_node(root.right, data)
112-
139+
113140
if root.left is None and root.right is None:
114141
root.height = 1
115142
return root
116-
143+
117144
left_height = get_height(root.left)
118145
right_height = get_height(root.right)
119-
146+
120147
if right_height - left_height == 2:
121148
right_right = get_height(root.right.right) if root.right else 0
122149
right_left = get_height(root.right.left) if root.right else 0
123-
# 使用三元表达式
124150
root = left_rotation(root) if right_right > right_left else rl_rotation(root)
125151
elif left_height - right_height == 2:
126152
left_left = get_height(root.left.left) if root.left else 0
127153
left_right = get_height(root.left.right) if root.left else 0
128-
# 使用三元表达式
129154
root = right_rotation(root) if left_left > left_right else lr_rotation(root)
130-
131-
root.height = my_max(get_height(root.right), get_height(root.left)) + 1
155+
156+
root.height = my_max(
157+
get_height(root.right),
158+
get_height(root.left)
159+
) + 1
132160
return root
133-
134161
class AVLTree:
135162
__slots__ = ("root",)
136-
137-
def __init__(self) -> None:
163+
164+
def __init__(self) -> None:
138165
self.root: MyNode | None = None
139-
140-
def get_height(self) -> int:
166+
167+
def get_height(self) -> int:
141168
return get_height(self.root)
142-
169+
143170
def insert(self, data: Any) -> None:
144171
self.root = insert_node(self.root, data)
145-
172+
146173
def delete(self, data: Any) -> None:
147174
self.root = del_node(self.root, data)
148-
175+
149176
def __str__(self) -> str:
150177
if self.root is None:
151178
return ""
152-
179+
153180
levels = []
154-
# 明确指定类型为 MyNode | None
155181
queue: list[MyNode | None] = [self.root]
156-
182+
157183
while queue:
158184
current = []
159185
next_level: list[MyNode | None] = []
160-
186+
161187
for node in queue:
162188
if node:
163189
current.append(str(node.data))
164190
next_level.append(node.left)
165191
next_level.append(node.right)
166192
else:
167193
current.append("*")
168-
next_level.extend([None, None])
169-
194+
next_level.append(None)
195+
next_level.append(None)
196+
170197
if any(node is not None for node in next_level):
171198
levels.append(" ".join(current))
172199
queue = next_level
173200
else:
174-
if current: # 添加最后一行
201+
if current:
175202
levels.append(" ".join(current))
176203
break
177-
204+
178205
return "\n".join(levels) + "\n" + "*"*36
179206

180207
def test_avl_tree() -> None:
181208
t = AVLTree()
182209
lst = list(range(10))
183210
random.shuffle(lst)
184-
211+
185212
for i in lst:
186213
t.insert(i)
187-
214+
188215
random.shuffle(lst)
189216
for i in lst:
190217
t.delete(i)

0 commit comments

Comments
 (0)