Skip to content

Commit e013166

Browse files
authored
Update skew_heap.py
1 parent 283601d commit e013166

File tree

1 file changed

+20
-23
lines changed

1 file changed

+20
-23
lines changed

data_structures/heap/skew_heap.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from __future__ import annotations
44

55
from collections.abc import Iterable, Iterator
6-
from typing import Any, Protocol, TypeVar
6+
from typing import Protocol, TypeVar
77

88

99
class Comparable(Protocol):
10-
def __lt__(self: T, other: T) -> bool: ...
10+
def __lt__(self, other: Any) -> bool: ...
1111

1212

1313
T = TypeVar("T", bound=Comparable)
@@ -59,11 +59,8 @@ def value(self) -> T:
5959
TypeError: SkewNode.__init__() missing 1 required positional argument: 'value'
6060
"""
6161
return self._value
62-
6362
@staticmethod
64-
def merge(
65-
root1: SkewNode[T] | None, root2: SkewNode[T] | None
66-
) -> SkewNode[T] | None:
63+
def merge(root1: SkewNode[T] | None, root2: SkewNode[T] | None) -> SkewNode[T] | None:
6764
"""
6865
Merge 2 nodes together.
6966
>>> SkewNode.merge(SkewNode(10), SkewNode(-10.5)).value
@@ -80,22 +77,23 @@ def merge(
8077

8178
if not root2:
8279
return root1
83-
84-
# 使用类型安全的比较方式
85-
if root1.value < root2.value:
86-
# root1 更小,不需要交换
87-
result = root1
88-
temp = root1.right
89-
result.right = root1.left
90-
result.left = SkewNode.merge(temp, root2)
91-
return result
92-
else:
93-
# root2 更小或相等,需要交换
94-
result = root2
95-
temp = root2.right
96-
result.right = root2.left
97-
result.left = SkewNode.merge(root1, temp)
98-
return result
80+
try:
81+
if root1.value < root2.value:
82+
result = root1
83+
temp = root1.right
84+
result.right = root1.left
85+
result.left = SkewNode.merge(temp, root2)
86+
return result
87+
except TypeError:
88+
# 回退到值比较
89+
pass
90+
91+
# 如果比较失败或 root2 更小
92+
result = root2
93+
temp = root2.right
94+
result.right = root2.left
95+
result.left = SkewNode.merge(root1, temp)
96+
return result
9997

10098

10199
class SkewHeap[T]:
@@ -206,7 +204,6 @@ def pop(self) -> T:
206204
self._root = None
207205

208206
return result
209-
210207
def top(self) -> T:
211208
"""
212209
Return the smallest value from the heap.

0 commit comments

Comments
 (0)