33from __future__ import annotations
44
55from collections .abc import Iterable , Iterator
6- from typing import Generic , Protocol , TypeVar
6+ from typing import Any , Callable , Optional
77
88
9- class Comparable (Protocol ):
10- def __lt__ (self , other : object ) -> bool : ...
11-
12-
13- T = TypeVar ("T" , bound = Comparable )
14-
15-
16- class SkewNode [T ]:
9+ class SkewNode :
1710 """
1811 One node of the skew heap. Contains the value and references to
1912 two children.
2013 """
2114
22- def __init__ (self , value : T ) -> None :
23- self ._value : T = value
24- self .left : SkewNode [ T ] | None = None
25- self .right : SkewNode [ T ] | None = None
15+ def __init__ (self , value : Any ) -> None :
16+ self ._value : Any = value
17+ self .left : Optional [ SkewNode ] = None
18+ self .right : Optional [ SkewNode ] = None
2619
2720 @property
28- def value (self ) -> T :
21+ def value (self ) -> Any :
2922 """
3023 Return the value of the node.
3124
@@ -44,17 +37,20 @@ def value(self) -> T:
4437
4538 @staticmethod
4639 def merge (
47- root1 : SkewNode [T ] | None , root2 : SkewNode [T ] | None
48- ) -> SkewNode [T ] | None :
40+ root1 : Optional [SkewNode ],
41+ root2 : Optional [SkewNode ],
42+ comp : Callable [[Any , Any ], bool ]
43+ ) -> Optional [SkewNode ]:
4944 """
5045 Merge two nodes together.
51- >>> SkewNode.merge(SkewNode(10), SkewNode(-10.5)).value
46+ >>> def comp(a, b): return a < b
47+ >>> SkewNode.merge(SkewNode(10), SkewNode(-10.5), comp).value
5248 -10.5
53- >>> SkewNode.merge(SkewNode(10), SkewNode(10.5)).value
49+ >>> SkewNode.merge(SkewNode(10), SkewNode(10.5), comp ).value
5450 10
55- >>> SkewNode.merge(SkewNode(10), SkewNode(10)).value
51+ >>> SkewNode.merge(SkewNode(10), SkewNode(10), comp ).value
5652 10
57- >>> SkewNode.merge(SkewNode(-100), SkewNode(-10.5)).value
53+ >>> SkewNode.merge(SkewNode(-100), SkewNode(-10.5), comp ).value
5854 -100
5955 """
6056 # Handle empty nodes
@@ -63,34 +59,23 @@ def merge(
6359 if not root2 :
6460 return root1
6561
66- # Compare values using explicit comparison function
67- if SkewNode . _is_less_than (root1 .value , root2 .value ):
62+ # Compare values using provided comparison function
63+ if comp (root1 .value , root2 .value ):
6864 # root1 is smaller, make it the new root
6965 result = root1
7066 temp = root1 .right
7167 result .right = root1 .left
72- result .left = SkewNode .merge (temp , root2 )
68+ result .left = SkewNode .merge (temp , root2 , comp )
7369 return result
7470 else :
7571 # root2 is smaller or equal, use it as new root
7672 result = root2
7773 temp = root2 .right
7874 result .right = root2 .left
79- result .left = SkewNode .merge (root1 , temp )
75+ result .left = SkewNode .merge (root1 , temp , comp )
8076 return result
8177
82- @staticmethod
83- def _is_less_than (a : T , b : T ) -> bool :
84- """Safe comparison function that avoids type checker issues"""
85- try :
86- return a < b
87- except TypeError :
88- # Fallback comparison for non-comparable types
89- # Uses string representation as last resort
90- return str (a ) < str (b )
91-
92-
93- class SkewHeap [T ]:
78+ class SkewHeap :
9479 """
9580 A data structure that allows inserting a new value and popping the smallest
9681 values. Both operations take O(logN) time where N is the size of the heap.
@@ -113,15 +98,25 @@ class SkewHeap[T]:
11398 [-1, 0, 1]
11499 """
115100
116- def __init__ (self , data : Iterable [T ] | None = ()) -> None :
101+ def __init__ (
102+ self ,
103+ data : Iterable [Any ] | None = None ,
104+ comp : Callable [[Any , Any ], bool ] = lambda a , b : a < b
105+ ) -> None :
117106 """
118- Initialize the skew heap with optional data
119-
107+ Initialize the skew heap with optional data and comparison function
108+
120109 >>> sh = SkewHeap([3, 1, 3, 7])
121110 >>> list(sh)
122111 [1, 3, 3, 7]
123- """
124- self ._root : SkewNode [T ] | None = None
112+
113+ # Max-heap example
114+ >>> max_heap = SkewHeap([3, 1, 3, 7], comp=lambda a, b: a > b)
115+ >>> list(max_heap)
116+ [7, 3, 3, 1]
117+ """
118+ self ._root : Optional [SkewNode ] = None
119+ self ._comp = comp
125120 if data :
126121 for item in data :
127122 self .insert (item )
@@ -142,7 +137,7 @@ def __bool__(self) -> bool:
142137 """
143138 return self ._root is not None
144139
145- def __iter__ (self ) -> Iterator [T ]:
140+ def __iter__ (self ) -> Iterator [Any ]:
146141 """
147142 Iterate through all values in sorted order
148143
@@ -151,8 +146,8 @@ def __iter__(self) -> Iterator[T]:
151146 [1, 3, 3, 7]
152147 """
153148 # Create a temporary heap for iteration
154- temp_heap : SkewHeap [ T ] = SkewHeap ()
155- result : list [T ] = []
149+ temp_heap = SkewHeap (comp = self . _comp )
150+ result : list [Any ] = []
156151
157152 # Pop all elements from the heap
158153 while self :
@@ -164,7 +159,7 @@ def __iter__(self) -> Iterator[T]:
164159 self ._root = temp_heap ._root
165160 return iter (result )
166161
167- def insert (self , value : T ) -> None :
162+ def insert (self , value : Any ) -> None :
168163 """
169164 Insert a new value into the heap
170165
@@ -176,9 +171,13 @@ def insert(self, value: T) -> None:
176171 >>> list(sh)
177172 [1, 3, 3, 7]
178173 """
179- self ._root = SkewNode .merge (self ._root , SkewNode (value ))
174+ self ._root = SkewNode .merge (
175+ self ._root ,
176+ SkewNode (value ),
177+ self ._comp
178+ )
180179
181- def pop (self ) -> T :
180+ def pop (self ) -> Any :
182181 """
183182 Remove and return the smallest value from the heap
184183
@@ -198,10 +197,14 @@ def pop(self) -> T:
198197 """
199198 result = self .top ()
200199 if self ._root :
201- self ._root = SkewNode .merge (self ._root .left , self ._root .right )
200+ self ._root = SkewNode .merge (
201+ self ._root .left ,
202+ self ._root .right ,
203+ self ._comp
204+ )
202205 return result
203206
204- def top (self ) -> T :
207+ def top (self ) -> Any :
205208 """
206209 Return the smallest value without removing it
207210
0 commit comments