55
66class MyQueue :
77 __slots__ = ("data" , "head" , "tail" )
8-
8+
99 def __init__ (self ) -> None :
1010 self .data : list [Any ] = []
1111 self .head = self .tail = 0
@@ -23,40 +23,63 @@ def pop(self) -> Any:
2323 return ret
2424
2525class MyNode :
26- __slots__ = ("data" , "height" , "left" , "right" ) # 按字母顺序排序
27-
26+ __slots__ = ("data" , "height" , "left" , "right" )
27+
2828 def __init__ (self , data : Any ) -> None :
2929 self .data = data
3030 self .height = 1
3131 self .left : MyNode | None = None
3232 self .right : MyNode | None = None
3333
34- def get_height (node : MyNode | None ) -> int :
34+ def get_height (node : MyNode | None ) -> int :
3535 return node .height if node else 0
3636
37- def my_max (a : int , b : int ) -> int :
37+ def my_max (a : int , b : int ) -> int :
3838 return a if a > b else b
3939
4040def right_rotation (node : MyNode ) -> MyNode :
4141 left_child = node .left
4242 if left_child is None :
4343 return node
44-
44+
4545 node .left = left_child .right
4646 left_child .right = node
47- node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
48- left_child .height = my_max (get_height (left_child .right ), get_height (left_child .left )) + 1
47+
48+ # 拆分长表达式
49+ node_height = my_max (
50+ get_height (node .right ),
51+ get_height (node .left )
52+ ) + 1
53+ node .height = node_height
54+
55+ left_height = my_max (
56+ get_height (left_child .right ),
57+ get_height (left_child .left )
58+ ) + 1
59+ left_child .height = left_height
60+
4961 return left_child
50-
5162def left_rotation (node : MyNode ) -> MyNode :
5263 right_child = node .right
5364 if right_child is None :
5465 return node
55-
66+
5667 node .right = right_child .left
5768 right_child .left = node
58- node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
59- right_child .height = my_max (get_height (right_child .right ), get_height (right_child .left )) + 1
69+
70+ # 拆分长表达式
71+ node_height = my_max (
72+ get_height (node .right ),
73+ get_height (node .left )
74+ ) + 1
75+ node .height = node_height
76+
77+ right_height = my_max (
78+ get_height (right_child .right ),
79+ get_height (right_child .left )
80+ ) + 1
81+ right_child .height = right_height
82+
6083 return right_child
6184
6285def lr_rotation (node : MyNode ) -> MyNode :
@@ -72,7 +95,7 @@ def rl_rotation(node: MyNode) -> MyNode:
7295def insert_node (node : MyNode | None , data : Any ) -> MyNode | None :
7396 if node is None :
7497 return MyNode (data )
75-
98+
7699 if data < node .data :
77100 node .left = insert_node (node .left , data )
78101 if get_height (node .left ) - get_height (node .right ) == 2 :
@@ -87,8 +110,11 @@ def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
87110 node = rl_rotation (node )
88111 else :
89112 node = left_rotation (node )
90-
91- node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
113+
114+ node .height = my_max (
115+ get_height (node .right ),
116+ get_height (node .left )
117+ ) + 1
92118 return node
93119
94120def get_left_most (root : MyNode ) -> Any :
@@ -99,7 +125,8 @@ def get_left_most(root: MyNode) -> Any:
99125def del_node (root : MyNode | None , data : Any ) -> MyNode | None :
100126 if root is None :
101127 return None
102- if data == root .data :
128+
129+ if data == root .data :
103130 if root .left and root .right :
104131 root .data = get_left_most (root .right )
105132 root .right = del_node (root .right , root .data )
@@ -109,82 +136,82 @@ def del_node(root: MyNode | None, data: Any) -> MyNode | None:
109136 root .left = del_node (root .left , data )
110137 else :
111138 root .right = del_node (root .right , data )
112-
139+
113140 if root .left is None and root .right is None :
114141 root .height = 1
115142 return root
116-
143+
117144 left_height = get_height (root .left )
118145 right_height = get_height (root .right )
119-
146+
120147 if right_height - left_height == 2 :
121148 right_right = get_height (root .right .right ) if root .right else 0
122149 right_left = get_height (root .right .left ) if root .right else 0
123- # 使用三元表达式
124150 root = left_rotation (root ) if right_right > right_left else rl_rotation (root )
125151 elif left_height - right_height == 2 :
126152 left_left = get_height (root .left .left ) if root .left else 0
127153 left_right = get_height (root .left .right ) if root .left else 0
128- # 使用三元表达式
129154 root = right_rotation (root ) if left_left > left_right else lr_rotation (root )
130-
131- root .height = my_max (get_height (root .right ), get_height (root .left )) + 1
155+
156+ root .height = my_max (
157+ get_height (root .right ),
158+ get_height (root .left )
159+ ) + 1
132160 return root
133-
134161class AVLTree :
135162 __slots__ = ("root" ,)
136-
137- def __init__ (self ) -> None :
163+
164+ def __init__ (self ) -> None :
138165 self .root : MyNode | None = None
139-
140- def get_height (self ) -> int :
166+
167+ def get_height (self ) -> int :
141168 return get_height (self .root )
142-
169+
143170 def insert (self , data : Any ) -> None :
144171 self .root = insert_node (self .root , data )
145-
172+
146173 def delete (self , data : Any ) -> None :
147174 self .root = del_node (self .root , data )
148-
175+
149176 def __str__ (self ) -> str :
150177 if self .root is None :
151178 return ""
152-
179+
153180 levels = []
154- # 明确指定类型为 MyNode | None
155181 queue : list [MyNode | None ] = [self .root ]
156-
182+
157183 while queue :
158184 current = []
159185 next_level : list [MyNode | None ] = []
160-
186+
161187 for node in queue :
162188 if node :
163189 current .append (str (node .data ))
164190 next_level .append (node .left )
165191 next_level .append (node .right )
166192 else :
167193 current .append ("*" )
168- next_level .extend ([None , None ])
169-
194+ next_level .append (None )
195+ next_level .append (None )
196+
170197 if any (node is not None for node in next_level ):
171198 levels .append (" " .join (current ))
172199 queue = next_level
173200 else :
174- if current : # 添加最后一行
201+ if current :
175202 levels .append (" " .join (current ))
176203 break
177-
204+
178205 return "\n " .join (levels ) + "\n " + "*" * 36
179206
180207def test_avl_tree () -> None :
181208 t = AVLTree ()
182209 lst = list (range (10 ))
183210 random .shuffle (lst )
184-
211+
185212 for i in lst :
186213 t .insert (i )
187-
214+
188215 random .shuffle (lst )
189216 for i in lst :
190217 t .delete (i )
0 commit comments