1818from  typing  import  TYPE_CHECKING , Any , Callable , Optional , SupportsIndex , TypeVar , Union 
1919
2020from  torch  import  Tensor 
21- from  typing_extensions  import  Self 
21+ from  typing_extensions  import  Self ,  overload 
2222
2323import  lightning .pytorch  as  pl 
2424from  lightning .pytorch .callbacks  import  Checkpoint 
@@ -108,6 +108,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
108108
109109
110110_T  =  TypeVar ("_T" )
111+ _PT  =  TypeVar ("_PT" )
111112
112113
113114class  _ListMap (list [_T ]):
@@ -139,27 +140,28 @@ class _ListMap(list[_T]):
139140
140141    """ 
141142
142-     def  __init__ (self , __iterable : Union [Mapping [str , _T ], Iterable [_T ]] =  None ):
143+     def  __init__ (self , __iterable : Optional [ Union [Mapping [str , _T ], Iterable [_T ] ]] =  None ):
143144        if  isinstance (__iterable , Mapping ):
144145            # super inits list with values 
145146            if  any (not  isinstance (x , str ) for  x  in  __iterable ):
146147                raise  TypeError ("When providing a Mapping, all keys must be of type str." )
147148            super ().__init__ (__iterable .values ())
148-             self . _dict  =  dict (zip (__iterable .keys (), range (len (__iterable ))))
149+             _dict  =  dict (zip (__iterable .keys (), range (len (__iterable ))))
149150        else :
150151            default_dict  =  {}
151152            if  isinstance (__iterable , _ListMap ):
152153                default_dict  =  __iterable ._dict .copy ()
153154            super ().__init__ (() if  __iterable  is  None  else  __iterable )
154-             self ._dict : dict  =  default_dict 
155+             _dict : dict  =  default_dict 
156+         self ._dict  =  _dict 
155157
156158    def  __eq__ (self , other : Any ) ->  bool :
157159        list_eq  =  list .__eq__ (self , other )
158160        if  isinstance (other , _ListMap ):
159161            return  list_eq  and  self ._dict  ==  other ._dict 
160162        return  list_eq 
161163
162-     def  copy (self ):
164+     def  copy (self )  ->   Self :
163165        new_listmap  =  _ListMap (self )
164166        new_listmap ._dict  =  self ._dict .copy ()
165167        return  new_listmap 
@@ -171,7 +173,16 @@ def extend(self, __iterable: Iterable[_T]) -> None:
171173                self ._dict [key ] =  idx  +  offset 
172174        super ().extend (__iterable )
173175
174-     def  pop (self , key : Union [SupportsIndex , str ] =  - 1 , default : Optional [Any ] =  None ) ->  _T :
176+     @overload  
177+     def  pop (self , key : SupportsIndex  =  - 1 , / ) ->  _T : ...
178+ 
179+     @overload  
180+     def  pop (self , key : str , / , default : _T ) ->  _T : ...
181+ 
182+     @overload  
183+     def  pop (self , key : str , default : _PT , / ) ->  Union [_T , _PT ]: ...
184+ 
185+     def  pop (self , key = - 1 , default = None ):
175186        if  isinstance (key , int ):
176187            ret  =  list .pop (self , key )
177188            for  str_key , idx  in  list (self ._dict .items ()):
@@ -211,7 +222,7 @@ def sort(
211222        reverse : bool  =  False ,
212223    ) ->  None :
213224        # Create a mapping from item to its name(s) 
214-         item_to_names  =  {}
225+         item_to_names :  dict [ _T ,  list [ int ]]  =  {}
215226        for  name , idx  in  self ._dict .items ():
216227            item  =  self [idx ]
217228            item_to_names .setdefault (item , []).append (name )
@@ -225,8 +236,13 @@ def sort(
225236                    new_dict [name ] =  idx 
226237        self ._dict  =  new_dict 
227238
228-     # --- List-like interface --- 
229-     def  __getitem__ (self , key : Union [int , slice , str ]) ->  _T :
239+     @overload  
240+     def  __getitem__ (self , key : Union [SupportsIndex , str ], / ) ->  _T : ...
241+ 
242+     @overload  
243+     def  __getitem__ (self , key : slice , / ) ->  list [_T ]: ...
244+ 
245+     def  __getitem__ (self , key , / ):
230246        if  isinstance (key , str ):
231247            return  self [self ._dict [key ]]
232248        return  list .__getitem__ (self , key )
@@ -245,7 +261,13 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self:
245261
246262        return  super ().__iadd__ (other )
247263
248-     def  __setitem__ (self , key : Union [SupportsIndex , slice , str ], value : _T ) ->  None :
264+     @overload  
265+     def  __setitem__ (self , key : Union [SupportsIndex , str ], value : _T , / ) ->  None : ...
266+ 
267+     @overload  
268+     def  __setitem__ (self , key : slice , value : Iterable [_T ], / ) ->  None : ...
269+ 
270+     def  __setitem__ (self , key , value , / ) ->  None :
249271        if  isinstance (key , (int , slice )):
250272            # replace element by index 
251273            return  list .__setitem__ (self , key , value )
@@ -259,14 +281,14 @@ def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
259281            return  None 
260282        raise  TypeError ("Key must be int or str" )
261283
262-     def  __contains__ (self , item : Union [_T , str ]) ->  bool :
284+     def  __contains__ (self , item : Union [object , str ]) ->  bool :
263285        if  isinstance (item , str ):
264286            return  item  in  self ._dict 
265287        return  list .__contains__ (self , item )
266288
267289    # --- Dict-like interface --- 
268290
269-     def  __delitem__ (self , key : Union [int , slice , str ]) ->  None :
291+     def  __delitem__ (self , key : Union [SupportsIndex , slice , str ]) ->  None :
270292        if  isinstance (key , (int , slice )):
271293            list .__delitem__ (self , key )
272294            for  _key  in  key .indices (len (self )) if  isinstance (key , slice ) else  [key ]:
@@ -294,7 +316,13 @@ def items(self) -> ItemsView[str, _T]:
294316        d  =  {k : self [v ] for  k , v  in  self ._dict .items ()}
295317        return  d .items ()
296318
297-     def  get (self , __key : str , default : Optional [Any ] =  None ) ->  _T :
319+     @overload  
320+     def  get (self , __key : str ) ->  Optional [_T ]: ...
321+ 
322+     @overload  
323+     def  get (self , __key : str , default : _PT ) ->  Union [_T , _PT ]: ...
324+ 
325+     def  get (self , __key : str , default = None ):
298326        if  __key  in  self ._dict :
299327            return  self [self ._dict [__key ]]
300328        return  default 
@@ -308,6 +336,6 @@ def reverse(self) -> None:
308336            self ._dict [key ] =  len (self ) -  1  -  idx 
309337        list .reverse (self )
310338
311-     def  clear (self ):
339+     def  clear (self )  ->   None :
312340        self ._dict .clear ()
313341        list .clear (self )
0 commit comments