33import random
44from typing import Any
55
6+
67class MyQueue :
78 __slots__ = ("data" , "head" , "tail" )
8-
9+
910 def __init__ (self ) -> None :
1011 self .data : list [Any ] = []
1112 self .head = self .tail = 0
@@ -22,80 +23,79 @@ def pop(self) -> Any:
2223 self .head += 1
2324 return ret
2425
26+
2527class MyNode :
2628 __slots__ = ("data" , "height" , "left" , "right" )
27-
29+
2830 def __init__ (self , data : Any ) -> None :
2931 self .data = data
3032 self .height = 1
3133 self .left : MyNode | None = None
3234 self .right : MyNode | None = None
3335
34- def get_height (node : MyNode | None ) -> int :
36+
37+ def get_height (node : MyNode | None ) -> int :
3538 return node .height if node else 0
3639
37- def my_max (a : int , b : int ) -> int :
40+
41+ def my_max (a : int , b : int ) -> int :
3842 return a if a > b else b
3943
44+
4045def right_rotation (node : MyNode ) -> MyNode :
4146 left_child = node .left
4247 if left_child is None :
4348 return node
44-
49+
4550 node .left = left_child .right
4651 left_child .right = node
47-
52+
4853 # 拆分长表达式
49- node_height = my_max (
50- get_height (node .right ),
51- get_height (node .left )
52- ) + 1
54+ node_height = my_max (get_height (node .right ), get_height (node .left )) + 1
5355 node .height = node_height
54-
55- left_height = my_max (
56- get_height (left_child .right ),
57- get_height (left_child .left )
58- ) + 1
56+
57+ left_height = my_max (get_height (left_child .right ), get_height (left_child .left )) + 1
5958 left_child .height = left_height
60-
59+
6160 return left_child
61+
62+
6263def left_rotation (node : MyNode ) -> MyNode :
6364 right_child = node .right
6465 if right_child is None :
6566 return node
66-
67+
6768 node .right = right_child .left
6869 right_child .left = node
69-
70+
7071 # 拆分长表达式
71- node_height = my_max (
72- get_height (node .right ),
73- get_height (node .left )
74- ) + 1
72+ node_height = my_max (get_height (node .right ), get_height (node .left )) + 1
7573 node .height = node_height
76-
77- right_height = my_max (
78- get_height (right_child .right ),
79- get_height (right_child .left )
80- ) + 1
74+
75+ right_height = (
76+ my_max (get_height (right_child .right ), get_height (right_child .left )) + 1
77+ )
8178 right_child .height = right_height
82-
79+
8380 return right_child
8481
82+
8583def lr_rotation (node : MyNode ) -> MyNode :
8684 if node .left :
8785 node .left = left_rotation (node .left )
8886 return right_rotation (node )
8987
88+
9089def rl_rotation (node : MyNode ) -> MyNode :
9190 if node .right :
9291 node .right = right_rotation (node .right )
9392 return left_rotation (node )
9493
94+
9595def insert_node (node : MyNode | None , data : Any ) -> MyNode | None :
9696 if node is None :
9797 return MyNode (data )
98-
98+
9999 if data < node .data :
100100 node .left = insert_node (node .left , data )
101101 if get_height (node .left ) - get_height (node .right ) == 2 :
@@ -110,22 +110,21 @@ def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
110110 node = rl_rotation (node )
111111 else :
112112 node = left_rotation (node )
113-
114- node .height = my_max (
115- get_height (node .right ),
116- get_height (node .left )
117- ) + 1
113+
114+ node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
118115 return node
119116
117+
120118def get_left_most (root : MyNode ) -> Any :
121119 while root .left :
122120 root = root .left
123121 return root .data
124122
123+
125124def del_node (root : MyNode | None , data : Any ) -> MyNode | None :
126125 if root is None :
127126 return None
128-
127+
129128 if data == root .data :
130129 if root .left and root .right :
131130 root .data = get_left_most (root .right )
@@ -136,14 +135,14 @@ def del_node(root: MyNode | None, data: Any) -> MyNode | None:
136135 root .left = del_node (root .left , data )
137136 else :
138137 root .right = del_node (root .right , data )
139-
138+
140139 if root .left is None and root .right is None :
141140 root .height = 1
142141 return root
143-
142+
144143 left_height = get_height (root .left )
145144 right_height = get_height (root .right )
146-
145+
147146 if right_height - left_height == 2 :
148147 right_right = get_height (root .right .right ) if root .right else 0
149148 right_left = get_height (root .right .left ) if root .right else 0
@@ -152,38 +151,37 @@ def del_node(root: MyNode | None, data: Any) -> MyNode | None:
152151 left_left = get_height (root .left .left ) if root .left else 0
153152 left_right = get_height (root .left .right ) if root .left else 0
154153 root = right_rotation (root ) if left_left > left_right else lr_rotation (root )
155-
156- root .height = my_max (
157- get_height (root .right ),
158- get_height (root .left )
159- ) + 1
154+
155+ root .height = my_max (get_height (root .right ), get_height (root .left )) + 1
160156 return root
157+
158+
161159class AVLTree :
162160 __slots__ = ("root" ,)
163-
164- def __init__ (self ) -> None :
161+
162+ def __init__ (self ) -> None :
165163 self .root : MyNode | None = None
166-
167- def get_height (self ) -> int :
164+
165+ def get_height (self ) -> int :
168166 return get_height (self .root )
169-
167+
170168 def insert (self , data : Any ) -> None :
171169 self .root = insert_node (self .root , data )
172-
170+
173171 def delete (self , data : Any ) -> None :
174172 self .root = del_node (self .root , data )
175-
173+
176174 def __str__ (self ) -> str :
177175 if self .root is None :
178176 return ""
179-
177+
180178 levels = []
181179 queue : list [MyNode | None ] = [self .root ]
182-
180+
183181 while queue :
184182 current = []
185183 next_level : list [MyNode | None ] = []
186-
184+
187185 for node in queue :
188186 if node :
189187 current .append (str (node .data ))
@@ -193,28 +191,30 @@ def __str__(self) -> str:
193191 current .append ("*" )
194192 next_level .append (None )
195193 next_level .append (None )
196-
194+
197195 if any (node is not None for node in next_level ):
198196 levels .append (" " .join (current ))
199197 queue = next_level
200198 else :
201199 if current :
202200 levels .append (" " .join (current ))
203201 break
204-
205- return "\n " .join (levels ) + "\n " + "*" * 36
202+
203+ return "\n " .join (levels ) + "\n " + "*" * 36
204+
206205
207206def test_avl_tree () -> None :
208207 t = AVLTree ()
209208 lst = list (range (10 ))
210209 random .shuffle (lst )
211-
210+
212211 for i in lst :
213212 t .insert (i )
214-
213+
215214 random .shuffle (lst )
216215 for i in lst :
217216 t .delete (i )
218217
218+
219219if __name__ == "__main__" :
220220 test_avl_tree ()
0 commit comments