1+ """
2+ Auto-balanced binary tree implementation
3+ For doctests: python3 -m doctest -v avl_tree.py
4+ For testing: python avl_tree.py
5+ """
6+
17from __future__ import annotations
28
39import math
@@ -28,19 +34,34 @@ def __init__(self, data: Any) -> None:
2834 self .left = self .right = None
2935 self .height = 1
3036
31- def get_data (self ) -> Any : return self .data
32- def get_left (self ) -> MyNode | None : return self .left
33- def get_right (self ) -> MyNode | None : return self .right
34- def get_height (self ) -> int : return self .height
35- def set_data (self , data : Any ) -> None : self .data = data
36- def set_left (self , node : MyNode | None ) -> None : self .left = node
37- def set_right (self , node : MyNode | None ) -> None : self .right = node
38- def set_height (self , height : int ) -> None : self .height = height
39-
40- def get_height (node : MyNode | None ) -> int :
37+ def get_data (self ) -> Any :
38+ return self .data
39+
40+ def get_left (self ) -> MyNode | None :
41+ return self .left
42+
43+ def get_right (self ) -> MyNode | None :
44+ return self .right
45+
46+ def get_height (self ) -> int :
47+ return self .height
48+
49+ def set_data (self , data : Any ) -> None :
50+ self .data = data
51+
52+ def set_left (self , node : MyNode | None ) -> None :
53+ self .left = node
54+
55+ def set_right (self , node : MyNode | None ) -> None :
56+ self .right = node
57+
58+ def set_height (self , height : int ) -> None :
59+ self .height = height
60+
61+ def get_height (node : MyNode | None ) -> int :
4162 return node .height if node else 0
4263
43- def my_max (a : int , b : int ) -> int :
64+ def my_max (a : int , b : int ) -> int :
4465 return a if a > b else b
4566
4667def right_rotation (node : MyNode ) -> MyNode :
@@ -51,7 +72,8 @@ def right_rotation(node: MyNode) -> MyNode:
5172 node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
5273 ret .height = my_max (get_height (ret .right ), get_height (ret .left )) + 1
5374 return ret
54- def left_rotation (node : MyNode ) -> MyNode :
75+
76+ def left_rotation (node : MyNode ) -> MyNode :
5577 print ("right rotation node:" , node .data )
5678 ret = node .right
5779 node .right = ret .left
@@ -69,17 +91,24 @@ def rl_rotation(node: MyNode) -> MyNode:
6991 return left_rotation (node )
7092
7193def insert_node (node : MyNode | None , data : Any ) -> MyNode | None :
72- if not node : return MyNode (data )
73-
94+ if not node :
95+ return MyNode (data )
96+
7497 if data < node .data :
7598 node .left = insert_node (node .left , data )
7699 if get_height (node .left ) - get_height (node .right ) == 2 :
77- node = right_rotation (node ) if data < node .left .data else lr_rotation (node )
100+ if data < node .left .data :
101+ node = right_rotation (node )
102+ else :
103+ node = lr_rotation (node )
78104 else :
79105 node .right = insert_node (node .right , data )
80106 if get_height (node .right ) - get_height (node .left ) == 2 :
81- node = rl_rotation (node ) if data < node .right .data else left_rotation (node )
82-
107+ if data < node .right .data :
108+ node = rl_rotation (node )
109+ else :
110+ node = left_rotation (node )
111+
83112 node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
84113 return node
85114
@@ -93,68 +122,91 @@ def del_node(root: MyNode, data: Any) -> MyNode | None:
93122 if root .left and root .right :
94123 root .data = get_extreme (root .right , False )
95124 root .right = del_node (root .right , root .data )
96- else : return root .left or root .right
125+ else :
126+ return root .left or root .right
97127 elif root .data > data :
98- if not root .left : return root
128+ if not root .left :
129+ return root
99130 root .left = del_node (root .left , data )
100- else : root .right = del_node (root .right , data )
101-
102- if get_height (root .right ) - get_height (root .left ) == 2 :
103- root = left_rotation (root ) if get_height (root .right .right ) > get_height (root .right .left ) else rl_rotation (root )
104- elif get_height (root .right ) - get_height (root .left ) == - 2 :
105- root = right_rotation (root ) if get_height (root .left .left ) > get_height (root .left .right ) else lr_rotation (root )
106-
131+ else :
132+ root .right = del_node (root .right , data )
133+
134+ # Handle balancing
135+ right_height = get_height (root .right )
136+ left_height = get_height (root .left )
137+
138+ if right_height - left_height == 2 :
139+ if get_height (root .right .right ) > get_height (root .right .left ):
140+ root = left_rotation (root )
141+ else :
142+ root = rl_rotation (root )
143+ elif right_height - left_height == - 2 :
144+ if get_height (root .left .left ) > get_height (root .left .right ):
145+ root = right_rotation (root )
146+ else :
147+ root = lr_rotation (root )
148+
107149 root .height = my_max (get_height (root .right ), get_height (root .left )) + 1
108150 return root
109- class AVLtree :
110- def __init__ (self ) -> None : self .root = None
111- def get_height (self ) -> int : return get_height (self .root )
112-
151+ class AVLtree :
152+ def __init__ (self ) -> None :
153+ self .root = None
154+
155+ def get_height (self ) -> int :
156+ return get_height (self .root )
157+
113158 def insert (self , data : Any ) -> None :
114159 print (f"insert:{ data } " )
115160 self .root = insert_node (self .root , data )
116-
161+
117162 def del_node (self , data : Any ) -> None :
118163 print (f"delete:{ data } " )
119- if not self .root : return
164+ if not self .root :
165+ return
120166 self .root = del_node (self .root , data )
121-
167+
122168 def __str__ (self ) -> str :
123- if not self .root : return ""
169+ if not self .root :
170+ return ""
124171 q , output , layer , cnt = MyQueue (), "" , self .get_height (), 0
125172 q .push (self .root )
126-
173+
127174 while not q .is_empty ():
128175 node = q .pop ()
129176 space = " " * int (2 ** (layer - 1 ))
130177 output += space + (str (node .data ) if node else "*" ) + space
131178 cnt += 1
132-
179+
133180 if node :
134181 q .push (node .left )
135182 q .push (node .right )
136183 else :
137184 q .push (None )
138185 q .push (None )
139-
140- if any (cnt == 2 ** i - 1 for i in range (10 )):
141- layer -= 1
142- output += "\n "
143- if layer == 0 : break
144-
186+
187+ for i in range (10 ):
188+ if cnt == 2 ** i - 1 :
189+ layer -= 1
190+ output += "\n "
191+ if layer == 0 :
192+ break
193+ break
194+
145195 return output + "\n " + "*" * 36
146- def _test () -> None : doctest .testmod ()
196+
197+ def _test () -> None :
198+ doctest .testmod ()
147199
148200if __name__ == "__main__" :
149201 _test ()
150202 t = AVLtree ()
151203 lst = list (range (10 ))
152204 random .shuffle (lst )
153-
205+
154206 for i in lst :
155207 t .insert (i )
156208 print (t )
157-
209+
158210 random .shuffle (lst )
159211 for i in lst :
160212 t .del_node (i )
0 commit comments