Skip to content

Commit 4c3fa76

Browse files
authored
Update lru_cache.py
1 parent e5d9bd4 commit 4c3fa76

File tree

1 file changed

+75
-44
lines changed

1 file changed

+75
-44
lines changed

other/lru_cache.py

Lines changed: 75 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Iterator
44
from pprint import pformat
5+
from typing import Optional, Tuple
56

67

78
class RedBlackTree:
@@ -62,7 +63,6 @@ def rotate_left(self) -> RedBlackTree:
6263
parent.right = right
6364
right.parent = parent
6465
return right
65-
6666
def rotate_right(self) -> RedBlackTree:
6767
"""Rotate the subtree rooted at this node to the right and
6868
returns the new root to this subtree.
@@ -99,12 +99,14 @@ def insert(self, label: int) -> RedBlackTree:
9999
return self
100100
elif self.label > label:
101101
if self.left:
102-
self.left.insert(label)
102+
# 更新:递归插入后更新左子树引用
103+
self.left = self.left.insert(label)
103104
else:
104105
self.left = RedBlackTree(label, 1, self)
105106
self.left._insert_repair()
106107
elif self.right:
107-
self.right.insert(label)
108+
# 更新:递归插入后更新右子树引用
109+
self.right = self.right.insert(label)
108110
else:
109111
self.right = RedBlackTree(label, 1, self)
110112
self.right._insert_repair()
@@ -149,7 +151,7 @@ def _insert_repair(self) -> None:
149151
self.grandparent._insert_repair()
150152

151153
def remove(self, label: int) -> RedBlackTree:
152-
"""Remove label from this tree."""
154+
"""Remove label from this tree, returning the new root of the subtree."""
153155
if self.label == label:
154156
if self.left and self.right:
155157
# It's easier to balance a node with at most one child,
@@ -158,7 +160,8 @@ def remove(self, label: int) -> RedBlackTree:
158160
value = self.left.get_max()
159161
if value is not None:
160162
self.label = value
161-
self.left.remove(value)
163+
# 更新:递归删除后更新左子树引用
164+
self.left = self.left.remove(value)
162165
else:
163166
# This node has at most one non-None child, so we don't
164167
# need to replace
@@ -198,9 +201,11 @@ def remove(self, label: int) -> RedBlackTree:
198201
self.right.parent = self
199202
elif self.label is not None and self.label > label:
200203
if self.left:
201-
self.left.remove(label)
204+
# 更新:递归删除后更新左子树引用
205+
self.left = self.left.remove(label)
202206
elif self.right:
203-
self.right.remove(label)
207+
# 更新:递归删除后更新右子树引用
208+
self.right = self.right.remove(label)
204209
return self.parent or self
205210

206211
def _remove_repair(self) -> None:
@@ -308,7 +313,6 @@ def check_color_properties(self) -> bool:
308313
return False
309314
# All properties were met
310315
return True
311-
312316
def check_coloring(self) -> bool:
313317
"""A helper function to recursively check Property 4 of a
314318
Red-Black Tree. See check_color_properties for more info.
@@ -320,24 +324,23 @@ def check_coloring(self) -> bool:
320324
return not (self.right and not self.right.check_coloring())
321325

322326
def black_height(self) -> int | None:
323-
"""Returns the number of black nodes from this node to the
324-
leaves of the tree, or None if there isn't one such value (the
325-
tree is color incorrectly).
326-
"""
327-
if self is None or self.left is None or self.right is None:
328-
# If we're already at a leaf, there is no path
327+
"""修正的黑色高度计算方法"""
328+
# 叶子节点(None)被视为黑色,高度为1
329+
if self is None:
329330
return 1
330-
left = RedBlackTree.black_height(self.left)
331-
right = RedBlackTree.black_height(self.right)
332-
if left is None or right is None:
333-
# There are issues with coloring below children nodes
331+
332+
# 递归计算左右子树高度
333+
left_bh = RedBlackTree.black_height(self.left)
334+
right_bh = RedBlackTree.black_height(self.right)
335+
336+
# 检查高度是否有效且一致
337+
if left_bh is None or right_bh is None:
334338
return None
335-
if left != right:
336-
# The two children have unequal depths
339+
if left_bh != right_bh:
337340
return None
338-
# Return the black depth of children, plus one if this node is
339-
# black
340-
return left + (1 - self.color)
341+
342+
# 返回当前节点高度(黑色节点+1)
343+
return left_bh + (1 - self.color)
341344

342345
# Here are functions which are general to all binary search trees
343346

@@ -364,7 +367,6 @@ def search(self, label: int) -> RedBlackTree | None:
364367
return None
365368
else:
366369
return self.left.search(label)
367-
368370
def floor(self, label: int) -> int | None:
369371
"""Returns the largest element in this tree which is at most label.
370372
This method is guaranteed to run in O(log(n)) time."""
@@ -437,7 +439,6 @@ def sibling(self) -> RedBlackTree | None:
437439
return self.parent.right
438440
else:
439441
return self.parent.left
440-
441442
def is_left(self) -> bool:
442443
"""Returns true iff this node is the left child of its parent."""
443444
if self.parent is None:
@@ -451,12 +452,13 @@ def is_right(self) -> bool:
451452
return self.parent.right is self
452453

453454
def __bool__(self) -> bool:
454-
return True
455+
"""空树返回False"""
456+
return self.label is not None
455457

456458
def __len__(self) -> int:
457-
"""
458-
Return the number of nodes in this tree.
459-
"""
459+
"""正确处理空树情况"""
460+
if self.label is None:
461+
return 0
460462
ln = 1
461463
if self.left:
462464
ln += len(self.left)
@@ -484,7 +486,6 @@ def postorder_traverse(self) -> Iterator[int | None]:
484486
if self.right:
485487
yield from self.right.postorder_traverse()
486488
yield self.label
487-
488489
def __repr__(self) -> str:
489490
if self.left is None and self.right is None:
490491
return f"'{self.label} {(self.color and 'red') or 'blk'}'"
@@ -502,18 +503,23 @@ def __eq__(self, other: object) -> bool:
502503
"""Test if two trees are equal."""
503504
if not isinstance(other, RedBlackTree):
504505
return NotImplemented
505-
if self.label == other.label:
506-
return self.left == other.left and self.right == other.right
507-
else:
506+
507+
# 处理空树比较
508+
if self.label is None and other.label is None:
509+
return True
510+
if self.label != other.label:
508511
return False
512+
513+
# 递归比较子树
514+
return (self.left == other.left) and (self.right == other.right)
515+
516+
# 明确表示该类的实例不可哈希
517+
__hash__ = None
509518

510519

511520
def color(node: RedBlackTree | None) -> int:
512521
"""Returns the color of a node, allowing for None leaves."""
513-
if node is None:
514-
return 0
515-
else:
516-
return node.color
522+
return 0 if node is None else node.color
517523

518524

519525
"""
@@ -555,7 +561,6 @@ def test_rotations() -> bool:
555561
right_rot.right.right.right = RedBlackTree(20, parent=right_rot.right.right)
556562
return tree == right_rot
557563

558-
559564
def test_insertion_speed() -> bool:
560565
"""Test that the tree balances inserts to O(log(n)) by doing a lot
561566
of them.
@@ -639,8 +644,6 @@ def test_floor_ceil() -> bool:
639644
if tree.floor(val) != floor or tree.ceil(val) != ceil:
640645
return False
641646
return True
642-
643-
644647
def test_min_max() -> bool:
645648
"""Tests the min and max functions in the tree."""
646649
tree = RedBlackTree(0)
@@ -668,7 +671,6 @@ def test_tree_traversal() -> bool:
668671
return False
669672
return list(tree.postorder_traverse()) == [-16, 8, 20, 24, 22, 16, 0]
670673

671-
672674
def test_tree_chaining() -> bool:
673675
"""Tests the three different tree chaining functions."""
674676
tree = RedBlackTree(0)
@@ -680,10 +682,37 @@ def test_tree_chaining() -> bool:
680682
return list(tree.postorder_traverse()) == [-16, 8, 20, 24, 22, 16, 0]
681683

682684

683-
def print_results(msg: str, passes: bool) -> None:
684-
print(str(msg), "works!" if passes else "doesn't work :(")
685+
def test_empty_tree() -> bool:
686+
"""Tests behavior with empty trees."""
687+
tree = RedBlackTree(None)
688+
689+
# 测试空树属性
690+
if tree.label is not None or tree.left or tree.right:
691+
return False
692+
693+
# 测试空树长度
694+
if len(tree) != 0:
695+
return False
696+
697+
# 测试空树布尔值
698+
if tree:
699+
return False
700+
701+
# 测试空树搜索
702+
if 0 in tree or tree.search(0):
703+
return False
704+
705+
# 测试空树删除
706+
try:
707+
tree.remove(0)
708+
except Exception:
709+
return False
710+
711+
return True
685712

686713

714+
def print_results(msg: str, passes: bool) -> None:
715+
print(str(msg), "works!" if passes else "doesn't work :(")
687716
def pytests() -> None:
688717
assert test_rotations()
689718
assert test_insert()
@@ -692,6 +721,7 @@ def pytests() -> None:
692721
assert test_floor_ceil()
693722
assert test_tree_traversal()
694723
assert test_tree_chaining()
724+
assert test_empty_tree()
695725

696726

697727
def main() -> None:
@@ -704,7 +734,8 @@ def main() -> None:
704734
print_results("Deleting", test_insert_delete())
705735
print_results("Floor and ceil", test_floor_ceil())
706736
print_results("Tree traversal", test_tree_traversal())
707-
print_results("Tree traversal", test_tree_chaining())
737+
print_results("Tree chaining", test_tree_chaining())
738+
print_results("Empty tree handling", test_empty_tree())
708739
print("Testing tree balancing...")
709740
print("This should only be a few seconds.")
710741
test_insertion_speed()

0 commit comments

Comments
 (0)