1818from typing import TYPE_CHECKING , Any , Callable , Optional , SupportsIndex , TypeVar , Union
1919
2020from torch import Tensor
21- from typing_extensions import Self
21+ from typing_extensions import Self , overload
2222
2323import lightning .pytorch as pl
2424from lightning .pytorch .callbacks import Checkpoint
@@ -108,6 +108,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
108108
109109
110110_T = TypeVar ("_T" )
111+ _PT = TypeVar ("_PT" )
111112
112113
113114class _ListMap (list [_T ]):
@@ -139,19 +140,20 @@ class _ListMap(list[_T]):
139140
140141 """
141142
142- def __init__ (self , __iterable : Union [Mapping [str , _T ], Iterable [_T ]] = None ):
143+ def __init__ (self , __iterable : Optional [ Union [Mapping [str , _T ], Iterable [_T ] ]] = None ):
143144 if isinstance (__iterable , Mapping ):
144145 # super inits list with values
145146 if any (not isinstance (x , str ) for x in __iterable ):
146147 raise TypeError ("When providing a Mapping, all keys must be of type str." )
147148 super ().__init__ (__iterable .values ())
148- self . _dict = dict (zip (__iterable .keys (), range (len (__iterable ))))
149+ _dict = dict (zip (__iterable .keys (), range (len (__iterable ))))
149150 else :
150151 default_dict = {}
151152 if isinstance (__iterable , _ListMap ):
152153 default_dict = __iterable ._dict .copy ()
153154 super ().__init__ (() if __iterable is None else __iterable )
154- self ._dict : dict = default_dict
155+ _dict : dict = default_dict
156+ self ._dict = _dict
155157
156158 def __eq__ (self , other : Any ) -> bool :
157159 list_eq = list .__eq__ (self , other )
@@ -171,7 +173,7 @@ def extend(self, __iterable: Iterable[_T]) -> None:
171173 self ._dict [key ] = idx + offset
172174 super ().extend (__iterable )
173175
174- def pop (self , key : Union [SupportsIndex , str ] = - 1 , default : Optional [Any ] = None ) -> _T :
176+ def pop (self , key : Union [SupportsIndex , str ] = - 1 , default : Optional [_PT ] = None ) -> Union [ _T , _PT ] :
175177 if isinstance (key , int ):
176178 ret = list .pop (self , key )
177179 for str_key , idx in list (self ._dict .items ()):
@@ -211,7 +213,7 @@ def sort(
211213 reverse : bool = False ,
212214 ) -> None :
213215 # Create a mapping from item to its name(s)
214- item_to_names = {}
216+ item_to_names : dict [ _T , list [ int ]] = {}
215217 for name , idx in self ._dict .items ():
216218 item = self [idx ]
217219 item_to_names .setdefault (item , []).append (name )
@@ -225,8 +227,13 @@ def sort(
225227 new_dict [name ] = idx
226228 self ._dict = new_dict
227229
228- # --- List-like interface ---
229- def __getitem__ (self , key : Union [int , slice , str ]) -> _T :
230+ @overload
231+ def __getitem__ (self , key : Union [SupportsIndex , str ], / ) -> _T : ...
232+
233+ @overload
234+ def __getitem__ (self , key : slice , / ) -> list [_T ]: ...
235+
236+ def __getitem__ (self , key , / ):
230237 if isinstance (key , str ):
231238 return self [self ._dict [key ]]
232239 return list .__getitem__ (self , key )
@@ -245,7 +252,13 @@ def __iadd__(self, other: Union[list[_T], Self]) -> Self:
245252
246253 return super ().__iadd__ (other )
247254
248- def __setitem__ (self , key : Union [SupportsIndex , slice , str ], value : _T ) -> None :
255+ @overload
256+ def __setitem__ (self , key : Union [SupportsIndex , str ], value : _T , / ) -> None : ...
257+
258+ @overload
259+ def __setitem__ (self , key : slice , value : Iterable [_T ], / ) -> None : ...
260+
261+ def __setitem__ (self , key , value , / ) -> None :
249262 if isinstance (key , (int , slice )):
250263 # replace element by index
251264 return list .__setitem__ (self , key , value )
@@ -259,7 +272,7 @@ def __setitem__(self, key: Union[SupportsIndex, slice, str], value: _T) -> None:
259272 return None
260273 raise TypeError ("Key must be int or str" )
261274
262- def __contains__ (self , item : Union [_T , str ]) -> bool :
275+ def __contains__ (self , item : Union [object , str ]) -> bool :
263276 if isinstance (item , str ):
264277 return item in self ._dict
265278 return list .__contains__ (self , item )
0 commit comments