22
33from collections .abc import Iterator
44from pprint import pformat
5+ from typing import Optional , Tuple
56
67
78class RedBlackTree :
@@ -62,7 +63,6 @@ def rotate_left(self) -> RedBlackTree:
6263 parent .right = right
6364 right .parent = parent
6465 return right
65-
6666 def rotate_right (self ) -> RedBlackTree :
6767 """Rotate the subtree rooted at this node to the right and
6868 returns the new root to this subtree.
@@ -99,12 +99,14 @@ def insert(self, label: int) -> RedBlackTree:
9999 return self
100100 elif self .label > label :
101101 if self .left :
102- self .left .insert (label )
102+ # 更新:递归插入后更新左子树引用
103+ self .left = self .left .insert (label )
103104 else :
104105 self .left = RedBlackTree (label , 1 , self )
105106 self .left ._insert_repair ()
106107 elif self .right :
107- self .right .insert (label )
108+ # 更新:递归插入后更新右子树引用
109+ self .right = self .right .insert (label )
108110 else :
109111 self .right = RedBlackTree (label , 1 , self )
110112 self .right ._insert_repair ()
@@ -149,7 +151,7 @@ def _insert_repair(self) -> None:
149151 self .grandparent ._insert_repair ()
150152
151153 def remove (self , label : int ) -> RedBlackTree :
152- """Remove label from this tree."""
154+ """Remove label from this tree, returning the new root of the subtree ."""
153155 if self .label == label :
154156 if self .left and self .right :
155157 # It's easier to balance a node with at most one child,
@@ -158,7 +160,8 @@ def remove(self, label: int) -> RedBlackTree:
158160 value = self .left .get_max ()
159161 if value is not None :
160162 self .label = value
161- self .left .remove (value )
163+ # 更新:递归删除后更新左子树引用
164+ self .left = self .left .remove (value )
162165 else :
163166 # This node has at most one non-None child, so we don't
164167 # need to replace
@@ -198,9 +201,11 @@ def remove(self, label: int) -> RedBlackTree:
198201 self .right .parent = self
199202 elif self .label is not None and self .label > label :
200203 if self .left :
201- self .left .remove (label )
204+ # 更新:递归删除后更新左子树引用
205+ self .left = self .left .remove (label )
202206 elif self .right :
203- self .right .remove (label )
207+ # 更新:递归删除后更新右子树引用
208+ self .right = self .right .remove (label )
204209 return self .parent or self
205210
206211 def _remove_repair (self ) -> None :
@@ -308,7 +313,6 @@ def check_color_properties(self) -> bool:
308313 return False
309314 # All properties were met
310315 return True
311-
312316 def check_coloring (self ) -> bool :
313317 """A helper function to recursively check Property 4 of a
314318 Red-Black Tree. See check_color_properties for more info.
@@ -320,24 +324,23 @@ def check_coloring(self) -> bool:
320324 return not (self .right and not self .right .check_coloring ())
321325
322326 def black_height (self ) -> int | None :
323- """Returns the number of black nodes from this node to the
324- leaves of the tree, or None if there isn't one such value (the
325- tree is color incorrectly).
326- """
327- if self is None or self .left is None or self .right is None :
328- # If we're already at a leaf, there is no path
327+ """修正的黑色高度计算方法"""
328+ # 叶子节点(None)被视为黑色,高度为1
329+ if self is None :
329330 return 1
330- left = RedBlackTree .black_height (self .left )
331- right = RedBlackTree .black_height (self .right )
332- if left is None or right is None :
333- # There are issues with coloring below children nodes
331+
332+ # 递归计算左右子树高度
333+ left_bh = RedBlackTree .black_height (self .left )
334+ right_bh = RedBlackTree .black_height (self .right )
335+
336+ # 检查高度是否有效且一致
337+ if left_bh is None or right_bh is None :
334338 return None
335- if left != right :
336- # The two children have unequal depths
339+ if left_bh != right_bh :
337340 return None
338- # Return the black depth of children, plus one if this node is
339- # black
340- return left + (1 - self .color )
341+
342+ # 返回当前节点高度(黑色节点+1)
343+ return left_bh + (1 - self .color )
341344
342345 # Here are functions which are general to all binary search trees
343346
@@ -364,7 +367,6 @@ def search(self, label: int) -> RedBlackTree | None:
364367 return None
365368 else :
366369 return self .left .search (label )
367-
368370 def floor (self , label : int ) -> int | None :
369371 """Returns the largest element in this tree which is at most label.
370372 This method is guaranteed to run in O(log(n)) time."""
@@ -437,7 +439,6 @@ def sibling(self) -> RedBlackTree | None:
437439 return self .parent .right
438440 else :
439441 return self .parent .left
440-
441442 def is_left (self ) -> bool :
442443 """Returns true iff this node is the left child of its parent."""
443444 if self .parent is None :
@@ -451,12 +452,13 @@ def is_right(self) -> bool:
451452 return self .parent .right is self
452453
453454 def __bool__ (self ) -> bool :
454- return True
455+ """空树返回False"""
456+ return self .label is not None
455457
456458 def __len__ (self ) -> int :
457- """
458- Return the number of nodes in this tree.
459- """
459+ """正确处理空树情况"""
460+ if self . label is None :
461+ return 0
460462 ln = 1
461463 if self .left :
462464 ln += len (self .left )
@@ -484,7 +486,6 @@ def postorder_traverse(self) -> Iterator[int | None]:
484486 if self .right :
485487 yield from self .right .postorder_traverse ()
486488 yield self .label
487-
488489 def __repr__ (self ) -> str :
489490 if self .left is None and self .right is None :
490491 return f"'{ self .label } { (self .color and 'red' ) or 'blk' } '"
@@ -502,18 +503,23 @@ def __eq__(self, other: object) -> bool:
502503 """Test if two trees are equal."""
503504 if not isinstance (other , RedBlackTree ):
504505 return NotImplemented
505- if self .label == other .label :
506- return self .left == other .left and self .right == other .right
507- else :
506+
507+ # 处理空树比较
508+ if self .label is None and other .label is None :
509+ return True
510+ if self .label != other .label :
508511 return False
512+
513+ # 递归比较子树
514+ return (self .left == other .left ) and (self .right == other .right )
515+
516+ # 明确表示该类的实例不可哈希
517+ __hash__ = None
509518
510519
511520def color (node : RedBlackTree | None ) -> int :
512521 """Returns the color of a node, allowing for None leaves."""
513- if node is None :
514- return 0
515- else :
516- return node .color
522+ return 0 if node is None else node .color
517523
518524
519525"""
@@ -555,7 +561,6 @@ def test_rotations() -> bool:
555561 right_rot .right .right .right = RedBlackTree (20 , parent = right_rot .right .right )
556562 return tree == right_rot
557563
558-
559564def test_insertion_speed () -> bool :
560565 """Test that the tree balances inserts to O(log(n)) by doing a lot
561566 of them.
@@ -639,8 +644,6 @@ def test_floor_ceil() -> bool:
639644 if tree .floor (val ) != floor or tree .ceil (val ) != ceil :
640645 return False
641646 return True
642-
643-
644647def test_min_max () -> bool :
645648 """Tests the min and max functions in the tree."""
646649 tree = RedBlackTree (0 )
@@ -668,7 +671,6 @@ def test_tree_traversal() -> bool:
668671 return False
669672 return list (tree .postorder_traverse ()) == [- 16 , 8 , 20 , 24 , 22 , 16 , 0 ]
670673
671-
672674def test_tree_chaining () -> bool :
673675 """Tests the three different tree chaining functions."""
674676 tree = RedBlackTree (0 )
@@ -680,10 +682,37 @@ def test_tree_chaining() -> bool:
680682 return list (tree .postorder_traverse ()) == [- 16 , 8 , 20 , 24 , 22 , 16 , 0 ]
681683
682684
683- def print_results (msg : str , passes : bool ) -> None :
684- print (str (msg ), "works!" if passes else "doesn't work :(" )
685+ def test_empty_tree () -> bool :
686+ """Tests behavior with empty trees."""
687+ tree = RedBlackTree (None )
688+
689+ # 测试空树属性
690+ if tree .label is not None or tree .left or tree .right :
691+ return False
692+
693+ # 测试空树长度
694+ if len (tree ) != 0 :
695+ return False
696+
697+ # 测试空树布尔值
698+ if tree :
699+ return False
700+
701+ # 测试空树搜索
702+ if 0 in tree or tree .search (0 ):
703+ return False
704+
705+ # 测试空树删除
706+ try :
707+ tree .remove (0 )
708+ except Exception :
709+ return False
710+
711+ return True
685712
686713
714+ def print_results (msg : str , passes : bool ) -> None :
715+ print (str (msg ), "works!" if passes else "doesn't work :(" )
687716def pytests () -> None :
688717 assert test_rotations ()
689718 assert test_insert ()
@@ -692,6 +721,7 @@ def pytests() -> None:
692721 assert test_floor_ceil ()
693722 assert test_tree_traversal ()
694723 assert test_tree_chaining ()
724+ assert test_empty_tree ()
695725
696726
697727def main () -> None :
@@ -704,7 +734,8 @@ def main() -> None:
704734 print_results ("Deleting" , test_insert_delete ())
705735 print_results ("Floor and ceil" , test_floor_ceil ())
706736 print_results ("Tree traversal" , test_tree_traversal ())
707- print_results ("Tree traversal" , test_tree_chaining ())
737+ print_results ("Tree chaining" , test_tree_chaining ())
738+ print_results ("Empty tree handling" , test_empty_tree ())
708739 print ("Testing tree balancing..." )
709740 print ("This should only be a few seconds." )
710741 test_insertion_speed ()
0 commit comments