|
7 | 7 |
|
8 | 8 |
|
9 | 9 | class Comparable(Protocol): |
10 | | - def __lt__(self, other: Any) -> bool: ... |
11 | | - def __gt__(self, other: Any) -> bool: ... |
| 10 | + def __lt__(self: T, other: T) -> bool: ... |
12 | 11 |
|
13 | 12 |
|
14 | 13 | T = TypeVar("T", bound=Comparable) |
@@ -60,24 +59,40 @@ def value(self) -> T: |
60 | 59 | TypeError: SkewNode.__init__() missing 1 required positional argument: 'value' |
61 | 60 | """ |
62 | 61 | return self._value |
63 | | - |
64 | 62 | @staticmethod |
65 | | - def merge( |
66 | | - root1: SkewNode[T] | None, root2: SkewNode[T] | None |
67 | | - ) -> SkewNode[T] | None: |
| 63 | + def merge(root1: SkewNode[T] | None, root2: SkewNode[T] | None) -> SkewNode[T] | None: |
| 64 | + """ |
| 65 | + Merge 2 nodes together. |
| 66 | + >>> SkewNode.merge(SkewNode(10), SkewNode(-10.5)).value |
| 67 | + -10.5 |
| 68 | + >>> SkewNode.merge(SkewNode(10), SkewNode(10.5)).value |
| 69 | + 10 |
| 70 | + >>> SkewNode.merge(SkewNode(10), SkewNode(10)).value |
| 71 | + 10 |
| 72 | + >>> SkewNode.merge(SkewNode(-100), SkewNode(-10.5)).value |
| 73 | + -100 |
| 74 | + """ |
68 | 75 | if not root1: |
69 | 76 | return root2 |
| 77 | + |
70 | 78 | if not root2: |
71 | 79 | return root1 |
72 | 80 |
|
73 | | - if root2.value < root1.value: |
74 | | - root1, root2 = root2, root1 |
75 | | - |
76 | | - result = root1 |
77 | | - temp = root1.right |
78 | | - result.right = root1.left |
79 | | - result.left = SkewNode.merge(temp, root2) |
80 | | - return result |
| 81 | + # 使用类型安全的比较方式 |
| 82 | + if root1.value < root2.value: |
| 83 | + # root1 更小,不需要交换 |
| 84 | + result = root1 |
| 85 | + temp = root1.right |
| 86 | + result.right = root1.left |
| 87 | + result.left = SkewNode.merge(temp, root2) |
| 88 | + return result |
| 89 | + else: |
| 90 | + # root2 更小或相等,需要交换 |
| 91 | + result = root2 |
| 92 | + temp = root2.right |
| 93 | + result.right = root2.left |
| 94 | + result.left = SkewNode.merge(root1, temp) |
| 95 | + return result |
81 | 96 |
|
82 | 97 |
|
83 | 98 | class SkewHeap[T]: |
@@ -182,9 +197,10 @@ def pop(self) -> T: |
182 | 197 | IndexError: Can't get top element for the empty heap. |
183 | 198 | """ |
184 | 199 | result = self.top() |
185 | | - self._root = ( |
186 | | - SkewNode.merge(self._root.left, self._root.right) if self._root else None |
187 | | - ) |
| 200 | + if self._root: |
| 201 | + self._root = SkewNode.merge(self._root.left, self._root.right) |
| 202 | + else: |
| 203 | + self._root = None |
188 | 204 |
|
189 | 205 | return result |
190 | 206 |
|
|
0 commit comments