Skip to content

Commit 96ade05

Browse files
authored
Update skew_heap.py
1 parent 3361c55 commit 96ade05

File tree

1 file changed

+52
-49
lines changed

1 file changed

+52
-49
lines changed

data_structures/heap/skew_heap.py

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,22 @@
33
from __future__ import annotations
44

55
from collections.abc import Iterable, Iterator
6-
from typing import Generic, Protocol, TypeVar
6+
from typing import Any, Callable, Optional
77

88

9-
class Comparable(Protocol):
10-
def __lt__(self, other: object) -> bool: ...
11-
12-
13-
T = TypeVar("T", bound=Comparable)
14-
15-
16-
class SkewNode[T]:
9+
class SkewNode:
1710
"""
1811
One node of the skew heap. Contains the value and references to
1912
two children.
2013
"""
2114

22-
def __init__(self, value: T) -> None:
23-
self._value: T = value
24-
self.left: SkewNode[T] | None = None
25-
self.right: SkewNode[T] | None = None
15+
def __init__(self, value: Any) -> None:
16+
self._value: Any = value
17+
self.left: Optional[SkewNode] = None
18+
self.right: Optional[SkewNode] = None
2619

2720
@property
28-
def value(self) -> T:
21+
def value(self) -> Any:
2922
"""
3023
Return the value of the node.
3124
@@ -44,17 +37,20 @@ def value(self) -> T:
4437

4538
@staticmethod
4639
def merge(
47-
root1: SkewNode[T] | None, root2: SkewNode[T] | None
48-
) -> SkewNode[T] | None:
40+
root1: Optional[SkewNode],
41+
root2: Optional[SkewNode],
42+
comp: Callable[[Any, Any], bool]
43+
) -> Optional[SkewNode]:
4944
"""
5045
Merge two nodes together.
51-
>>> SkewNode.merge(SkewNode(10), SkewNode(-10.5)).value
46+
>>> def comp(a, b): return a < b
47+
>>> SkewNode.merge(SkewNode(10), SkewNode(-10.5), comp).value
5248
-10.5
53-
>>> SkewNode.merge(SkewNode(10), SkewNode(10.5)).value
49+
>>> SkewNode.merge(SkewNode(10), SkewNode(10.5), comp).value
5450
10
55-
>>> SkewNode.merge(SkewNode(10), SkewNode(10)).value
51+
>>> SkewNode.merge(SkewNode(10), SkewNode(10), comp).value
5652
10
57-
>>> SkewNode.merge(SkewNode(-100), SkewNode(-10.5)).value
53+
>>> SkewNode.merge(SkewNode(-100), SkewNode(-10.5), comp).value
5854
-100
5955
"""
6056
# Handle empty nodes
@@ -63,34 +59,23 @@ def merge(
6359
if not root2:
6460
return root1
6561

66-
# Compare values using explicit comparison function
67-
if SkewNode._is_less_than(root1.value, root2.value):
62+
# Compare values using provided comparison function
63+
if comp(root1.value, root2.value):
6864
# root1 is smaller, make it the new root
6965
result = root1
7066
temp = root1.right
7167
result.right = root1.left
72-
result.left = SkewNode.merge(temp, root2)
68+
result.left = SkewNode.merge(temp, root2, comp)
7369
return result
7470
else:
7571
# root2 is smaller or equal, use it as new root
7672
result = root2
7773
temp = root2.right
7874
result.right = root2.left
79-
result.left = SkewNode.merge(root1, temp)
75+
result.left = SkewNode.merge(root1, temp, comp)
8076
return result
8177

82-
@staticmethod
83-
def _is_less_than(a: T, b: T) -> bool:
84-
"""Safe comparison function that avoids type checker issues"""
85-
try:
86-
return a < b
87-
except TypeError:
88-
# Fallback comparison for non-comparable types
89-
# Uses string representation as last resort
90-
return str(a) < str(b)
91-
92-
93-
class SkewHeap[T]:
78+
class SkewHeap:
9479
"""
9580
A data structure that allows inserting a new value and popping the smallest
9681
values. Both operations take O(logN) time where N is the size of the heap.
@@ -113,15 +98,25 @@ class SkewHeap[T]:
11398
[-1, 0, 1]
11499
"""
115100

116-
def __init__(self, data: Iterable[T] | None = ()) -> None:
101+
def __init__(
102+
self,
103+
data: Iterable[Any] | None = None,
104+
comp: Callable[[Any, Any], bool] = lambda a, b: a < b
105+
) -> None:
117106
"""
118-
Initialize the skew heap with optional data
119-
107+
Initialize the skew heap with optional data and comparison function
108+
120109
>>> sh = SkewHeap([3, 1, 3, 7])
121110
>>> list(sh)
122111
[1, 3, 3, 7]
123-
"""
124-
self._root: SkewNode[T] | None = None
112+
113+
# Max-heap example
114+
>>> max_heap = SkewHeap([3, 1, 3, 7], comp=lambda a, b: a > b)
115+
>>> list(max_heap)
116+
[7, 3, 3, 1]
117+
"""
118+
self._root: Optional[SkewNode] = None
119+
self._comp = comp
125120
if data:
126121
for item in data:
127122
self.insert(item)
@@ -142,7 +137,7 @@ def __bool__(self) -> bool:
142137
"""
143138
return self._root is not None
144139

145-
def __iter__(self) -> Iterator[T]:
140+
def __iter__(self) -> Iterator[Any]:
146141
"""
147142
Iterate through all values in sorted order
148143
@@ -151,8 +146,8 @@ def __iter__(self) -> Iterator[T]:
151146
[1, 3, 3, 7]
152147
"""
153148
# Create a temporary heap for iteration
154-
temp_heap: SkewHeap[T] = SkewHeap()
155-
result: list[T] = []
149+
temp_heap = SkewHeap(comp=self._comp)
150+
result: list[Any] = []
156151

157152
# Pop all elements from the heap
158153
while self:
@@ -164,7 +159,7 @@ def __iter__(self) -> Iterator[T]:
164159
self._root = temp_heap._root
165160
return iter(result)
166161

167-
def insert(self, value: T) -> None:
162+
def insert(self, value: Any) -> None:
168163
"""
169164
Insert a new value into the heap
170165
@@ -176,9 +171,13 @@ def insert(self, value: T) -> None:
176171
>>> list(sh)
177172
[1, 3, 3, 7]
178173
"""
179-
self._root = SkewNode.merge(self._root, SkewNode(value))
174+
self._root = SkewNode.merge(
175+
self._root,
176+
SkewNode(value),
177+
self._comp
178+
)
180179

181-
def pop(self) -> T:
180+
def pop(self) -> Any:
182181
"""
183182
Remove and return the smallest value from the heap
184183
@@ -198,10 +197,14 @@ def pop(self) -> T:
198197
"""
199198
result = self.top()
200199
if self._root:
201-
self._root = SkewNode.merge(self._root.left, self._root.right)
200+
self._root = SkewNode.merge(
201+
self._root.left,
202+
self._root.right,
203+
self._comp
204+
)
202205
return result
203206

204-
def top(self) -> T:
207+
def top(self) -> Any:
205208
"""
206209
Return the smallest value without removing it
207210

0 commit comments

Comments
 (0)