1- # pyright: reportIncompatibleMethodOverride=false
2- # pyright: reportMissingTypeStubs=false
3-
41import contextlib
5- from collections .abc import (
6- Collection ,
7- Iterable ,
8- Iterator ,
9- Mapping ,
10- MutableMapping ,
11- )
12- from typing import Generic , Optional , Protocol , TypeVar , Union , cast , overload
2+ from collections .abc import Mapping , MutableMapping
133
144from sortedcontainers import SortedDict
15- from sortedcontainers .sorteddict import ItemsView , KeysView , ValuesView
165
176from .const import Bound
187from .interval import Interval
198
209
21- def _sortkey (i : tuple [ Interval , object ]) -> tuple [ object , bool ] :
10+ def _sortkey (i ) :
2211 # Sort by lower bound, closed first
2312 return (i [0 ].lower , i [0 ].left is Bound .OPEN )
2413
2514
26- V = TypeVar ("V" )
27-
28-
29- class HowToCombineSingle (Protocol ):
30- def __call__ (self , x : V , y : V ) -> V : ...
31-
32-
33- class HowToCombineWithInterval (Protocol ):
34- def __call__ (self , x : V , y : V , i : Interval ) -> V : ...
35-
36-
37- class IntervalDict (Generic [V ], MutableMapping [object , V ]):
15+ class IntervalDict (MutableMapping ):
3816 """
3917 An IntervalDict is a dict-like data structure that maps from intervals to data,where
4018 keys can be single values or Interval instances.
@@ -52,17 +30,12 @@ class IntervalDict(Generic[V], MutableMapping[object, V]):
5230 values (not keys) that are stored.
5331 """
5432
55- __slots__ : tuple [ str , ...] = ("_storage" ,)
33+ __slots__ = ("_storage" ,)
5634
5735 # Class to use when creating Interval instances
58- _klass : type = Interval
36+ _klass = Interval
5937
60- def __init__ (
61- self ,
62- mapping_or_iterable : Union [
63- Mapping [object , V ], Iterable [tuple [object , V ]], None
64- ] = None ,
65- ):
38+ def __init__ (self , mapping_or_iterable = None ):
6639 """
6740 Return a new IntervalDict.
6841
@@ -73,15 +46,13 @@ def __init__(
7346
7447 :param mapping_or_iterable: optional mapping or iterable.
7548 """
76- self ._storage : SortedDict = SortedDict (
77- _sortkey
78- ) # Mapping from intervals to values
49+ self ._storage = SortedDict (_sortkey ) # Mapping from intervals to values
7950
8051 if mapping_or_iterable is not None :
8152 self .update (mapping_or_iterable )
8253
8354 @classmethod
84- def _from_items (cls , items : Collection [ tuple [ object , V ]] ):
55+ def _from_items (cls , items ):
8556 """
8657 Fast creation of an IntervalDict with the provided items.
8758
@@ -111,17 +82,7 @@ def copy(self):
11182 """
11283 return self .__class__ ._from_items (self .items ())
11384
114- @overload
115- def get (
116- self , key : Interval , default : Optional [V ] = None
117- ) -> "IntervalDict[V] | None" : ...
118-
119- @overload
120- def get (self , key : object , default : V = None ) -> Optional [V ]: ...
121-
122- def get (
123- self , key : Union [object , Interval ], default : Optional [V ] = None
124- ) -> Union ["IntervalDict[V]" , V , None ]:
85+ def get (self , key , default = None ):
12586 """
12687 Return the values associated to given key.
12788
@@ -144,26 +105,17 @@ def get(
144105 except KeyError :
145106 return default
146107
147- def find (self , value : V ) -> Interval :
108+ def find (self , value ) :
148109 """
149110 Return a (possibly empty) Interval i such that self[i] = value, and
150111 self[~i] != value.
151112
152113 :param value: value to look for.
153114 :return: an Interval instance.
154115 """
155- return cast (
156- Interval ,
157- self ._klass (
158- * (
159- i
160- for i , v in cast (ItemsView [Interval , V ], self ._storage .items ())
161- if v == value
162- )
163- ),
164- )
116+ return self ._klass (* (i for i , v in self ._storage .items () if v == value ))
165117
166- def items (self ) -> ItemsView [ Interval , V ] :
118+ def items (self ):
167119 """
168120 Return a view object on the contained items sorted by their key
169121 (see https://docs.python.org/3/library/stdtypes.html#dict-views).
@@ -172,7 +124,7 @@ def items(self) -> ItemsView[Interval, V]:
172124 """
173125 return self ._storage .items ()
174126
175- def keys (self ) -> KeysView [ Interval ] :
127+ def keys (self ):
176128 """
177129 Return a view object on the contained keys (sorted)
178130 (see https://docs.python.org/3/library/stdtypes.html#dict-views).
@@ -181,7 +133,7 @@ def keys(self) -> KeysView[Interval]:
181133 """
182134 return self ._storage .keys ()
183135
184- def values (self ) -> ValuesView [ V ] :
136+ def values (self ):
185137 """
186138 Return a view object on the contained values sorted by their key
187139 (see https://docs.python.org/3/library/stdtypes.html#dict-views).
@@ -190,23 +142,15 @@ def values(self) -> ValuesView[V]:
190142 """
191143 return self ._storage .values ()
192144
193- def domain (self ) -> Interval :
145+ def domain (self ):
194146 """
195147 Return an Interval corresponding to the domain of this IntervalDict.
196148
197149 :return: an Interval.
198150 """
199- return cast ( Interval , self ._klass (* self ._storage .keys () ))
151+ return self ._klass (* self ._storage .keys ())
200152
201- @overload
202- def pop (self , key : Interval , default : Optional [V ] = None ) -> "IntervalDict[V]" : ...
203-
204- @overload
205- def pop (self , key : object , default : Optional [V ] = None ) -> Optional [V ]: ...
206-
207- def pop (
208- self , key : object , default : Optional [V ] = None
209- ) -> Union ["IntervalDict[V]" , V , None ]:
153+ def pop (self , key , default = None ):
210154 """
211155 Remove key and return the corresponding value if key is not an Interval.
212156 If key is an interval, it returns an IntervalDict instance.
@@ -229,26 +173,16 @@ def pop(
229173 del self [key ]
230174 return value
231175
232- def popitem (self ) -> tuple [ Interval , V ] :
176+ def popitem (self ):
233177 """
234178 Remove and return some (key, value) pair as a 2-tuple.
235179 Raise KeyError if D is empty.
236180
237181 :return: a (key, value) pair.
238182 """
239- return cast (tuple [Interval , V ], self ._storage .popitem ())
240-
241- @overload
242- def setdefault (
243- self , key : Interval , default : Optional [V ] = None
244- ) -> "IntervalDict[V]" : ...
183+ return self ._storage .popitem ()
245184
246- @overload
247- def setdefault (self , key : object , default : Optional [V ] = None ) -> V : ...
248-
249- def setdefault (
250- self , key : object , default : Optional [V ] = None
251- ) -> Union [V , "IntervalDict[V]" , None ]:
185+ def setdefault (self , key , default = None ):
252186 """
253187 Return given key. If it does not exist, set its value to given default
254188 and return it.
@@ -259,23 +193,16 @@ def setdefault(
259193 """
260194 if isinstance (key , Interval ):
261195 value = self .get (key , default )
262- if value is not None :
263- self .update (value )
196+ self .update (value )
264197 return value
265198 else :
266199 try :
267200 return self [key ]
268201 except KeyError :
269- if default is not None :
270- self [key ] = default
202+ self [key ] = default
271203 return default
272204
273- def update (
274- self ,
275- mapping_or_iterable : Union [
276- Mapping [object , V ], Iterable [tuple [object , V ]], type ["IntervalDict[V]" ]
277- ],
278- ):
205+ def update (self , mapping_or_iterable ):
279206 """
280207 Update current IntervalDict with provided values.
281208
@@ -286,21 +213,14 @@ def update(
286213 :param mapping_or_iterable: mapping or iterable.
287214 """
288215 if isinstance (mapping_or_iterable , Mapping ):
289- data = cast ( ItemsView [ Interval , V ], mapping_or_iterable .items () )
216+ data = mapping_or_iterable .items ()
290217 else :
291218 data = mapping_or_iterable
292219
293- for i , v in cast ( Collection [ tuple [ object , V ]], data ) :
220+ for i , v in data :
294221 self [i ] = v
295222
296- def combine (
297- self ,
298- other : "IntervalDict[V]" ,
299- how : Union [HowToCombineSingle , HowToCombineWithInterval ],
300- * ,
301- missing : V = ...,
302- pass_interval : bool = False ,
303- ) -> "IntervalDict[V]" :
223+ def combine (self , other , how , * , missing = ..., pass_interval = False ):
304224 """
305225 Return a new IntervalDict that combines the values from current and
306226 provided IntervalDict.
@@ -326,12 +246,10 @@ def combine(
326246 new_items = []
327247
328248 if not pass_interval :
329- how = cast (HowToCombineSingle , how )
330249
331- def _how (x : V , y : V , i : Interval ) -> V :
250+ def _how (x , y , i ) :
332251 return how (x , y )
333252 else :
334- how = cast (HowToCombineWithInterval , how )
335253 _how = how
336254 dom1 , dom2 = self .domain (), other .domain ()
337255
@@ -356,7 +274,7 @@ def _how(x: V, y: V, i: Interval) -> V:
356274
357275 return self .__class__ (new_items )
358276
359- def as_dict (self , atomic : bool = False ) -> dict [ Interval , V ] :
277+ def as_dict (self , atomic = False ):
360278 """
361279 Return the content as a classical Python dict.
362280
@@ -372,16 +290,10 @@ def as_dict(self, atomic: bool = False) -> dict[Interval, V]:
372290 else :
373291 return dict (self ._storage )
374292
375- @overload
376- def __getitem__ (self , key : Interval ) -> "IntervalDict[V]" : ...
377-
378- @overload
379- def __getitem__ (self , key : object ) -> V : ...
380-
381- def __getitem__ (self , key : Union [object , Interval ]) -> Union [V , "IntervalDict[V]" ]:
293+ def __getitem__ (self , key ):
382294 if isinstance (key , Interval ):
383295 items = []
384- for i , v in cast ( ItemsView [ Interval , V ], self ._storage .items () ):
296+ for i , v in self ._storage .items ():
385297 # Early out
386298 if key .upper < i .lower :
387299 break
@@ -391,21 +303,19 @@ def __getitem__(self, key: Union[object, Interval]) -> Union[V, "IntervalDict[V]
391303 items .append ((intersection , v ))
392304 return self .__class__ ._from_items (items )
393305 else :
394- for i , v in cast ( ItemsView [ Interval , V ], self ._storage .items () ):
306+ for i , v in self ._storage .items ():
395307 # Early out
396308 if key < i .lower :
397309 break
398310 if key in i :
399311 return v
400312 raise KeyError (key )
401313
402- def __setitem__ (self , key : Union [ object , Interval ], value : Optional [ V ] ):
314+ def __setitem__ (self , key , value ):
403315 if isinstance (key , Interval ):
404316 interval = key
405317 else :
406- interval = cast (
407- Interval , self ._klass .from_atomic (Bound .CLOSED , key , key , Bound .CLOSED )
408- )
318+ interval = self ._klass .from_atomic (Bound .CLOSED , key , key , Bound .CLOSED )
409319
410320 if interval .empty :
411321 return
@@ -414,7 +324,7 @@ def __setitem__(self, key: Union[object, Interval], value: Optional[V]):
414324 added_items = []
415325
416326 found = False
417- for i , v in cast ( ItemsView [ Interval , V ], self ._storage .items () ):
327+ for i , v in self ._storage .items ():
418328 if value == v :
419329 found = True
420330 # Extend existing key
@@ -437,7 +347,7 @@ def __setitem__(self, key: Union[object, Interval], value: Optional[V]):
437347 for key , value in added_items :
438348 self ._storage [key ] = value
439349
440- def __delitem__ (self , key : Union [ object , Interval ] ):
350+ def __delitem__ (self , key ):
441351 if isinstance (key , Interval ):
442352 interval = key
443353 else :
@@ -472,31 +382,31 @@ def __delitem__(self, key: Union[object, Interval]):
472382 for key , value in added_items :
473383 self ._storage [key ] = value
474384
475- def __or__ (self , other : "IntervalDict[V]" ) -> "IntervalDict[V]" :
385+ def __or__ (self , other ) :
476386 d = self .copy ()
477387 d .update (other )
478388 return d
479389
480- def __ior__ (self , other : "IntervalDict[V]" ) -> "IntervalDict[V]" :
390+ def __ior__ (self , other ) :
481391 self .update (other )
482392 return self
483393
484- def __iter__ (self ) -> Iterator [ object ] :
394+ def __iter__ (self ):
485395 return iter (self ._storage )
486396
487- def __len__ (self ) -> int :
397+ def __len__ (self ):
488398 return len (self ._storage )
489399
490- def __contains__ (self , key : object ) -> bool :
400+ def __contains__ (self , key ) :
491401 return key in self .domain ()
492402
493403 def __repr__ (self ):
494404 return "{{{}}}" .format (
495405 ", " .join (f"{ i !r} : { v !r} " for i , v in self .items ()),
496406 )
497407
498- def __eq__ (self , other : object ) -> bool :
408+ def __eq__ (self , other ) :
499409 if isinstance (other , IntervalDict ):
500410 return self .as_dict () == other .as_dict ()
501-
502- return NotImplemented
411+ else :
412+ return NotImplemented
0 commit comments