1515import shutil
1616import sys
1717from collections import ChainMap , OrderedDict , defaultdict
18+ from collections .abc import Iterable , Iterator
1819from dataclasses import dataclass
19- from typing import Any , DefaultDict , Iterable , Iterator , List , Optional , Tuple , Union
20+ from typing import Any , Optional , Union
2021
2122from lightning_utilities .core .apply_func import apply_to_collection
2223from torch import Tensor
@@ -67,17 +68,17 @@ def __init__(
6768 self .verbose = verbose
6869 self .inference_mode = inference_mode
6970 self .batch_progress = _BatchProgress () # across dataloaders
70- self ._max_batches : List [Union [int , float ]] = []
71+ self ._max_batches : list [Union [int , float ]] = []
7172
7273 self ._results = _ResultCollection (training = False )
73- self ._logged_outputs : List [_OUT_DICT ] = []
74+ self ._logged_outputs : list [_OUT_DICT ] = []
7475 self ._has_run : bool = False
7576 self ._trainer_fn = trainer_fn
7677 self ._stage = stage
7778 self ._data_source = _DataLoaderSource (None , f"{ stage .dataloader_prefix } _dataloader" )
7879 self ._combined_loader : Optional [CombinedLoader ] = None
7980 self ._data_fetcher : Optional [_DataFetcher ] = None
80- self ._seen_batches_per_dataloader : DefaultDict [int , int ] = defaultdict (int )
81+ self ._seen_batches_per_dataloader : defaultdict [int , int ] = defaultdict (int )
8182 self ._last_val_dl_reload_epoch = float ("-inf" )
8283 self ._module_mode = _ModuleMode ()
8384 self ._restart_stage = RestartStage .NONE
@@ -90,7 +91,7 @@ def num_dataloaders(self) -> int:
9091 return len (combined_loader .flattened )
9192
9293 @property
93- def max_batches (self ) -> List [Union [int , float ]]:
94+ def max_batches (self ) -> list [Union [int , float ]]:
9495 """The max number of batches to run per dataloader."""
9596 max_batches = self ._max_batches
9697 if not self .trainer .sanity_checking :
@@ -114,7 +115,7 @@ def _is_sequential(self) -> bool:
114115 return self ._combined_loader ._mode == "sequential"
115116
116117 @_no_grad_context
117- def run (self ) -> List [_OUT_DICT ]:
118+ def run (self ) -> list [_OUT_DICT ]:
118119 self .setup_data ()
119120 if self .skip :
120121 return []
@@ -280,7 +281,7 @@ def on_run_start(self) -> None:
280281 self ._on_evaluation_start ()
281282 self ._on_evaluation_epoch_start ()
282283
283- def on_run_end (self ) -> List [_OUT_DICT ]:
284+ def on_run_end (self ) -> list [_OUT_DICT ]:
284285 """Runs the ``_on_evaluation_epoch_end`` hook."""
285286 # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
286287 self .trainer ._logger_connector .epoch_end_reached ()
@@ -508,7 +509,7 @@ def _verify_dataloader_idx_requirement(self) -> None:
508509 )
509510
510511 @staticmethod
511- def _get_keys (data : dict ) -> Iterable [Tuple [str , ...]]:
512+ def _get_keys (data : dict ) -> Iterable [tuple [str , ...]]:
512513 for k , v in data .items ():
513514 if isinstance (v , dict ):
514515 for new_key in apply_to_collection (v , dict , _EvaluationLoop ._get_keys ):
@@ -527,7 +528,7 @@ def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]:
527528 return _EvaluationLoop ._find_value (result , rest )
528529
529530 @staticmethod
530- def _print_results (results : List [_OUT_DICT ], stage : str ) -> None :
531+ def _print_results (results : list [_OUT_DICT ], stage : str ) -> None :
531532 # remove the dl idx suffix
532533 results = [{k .split ("/dataloader_idx_" )[0 ]: v for k , v in result .items ()} for result in results ]
533534 metrics_paths = {k for keys in apply_to_collection (results , dict , _EvaluationLoop ._get_keys ) for k in keys }
@@ -544,7 +545,7 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None:
544545 term_size = shutil .get_terminal_size (fallback = (120 , 30 )).columns or 120
545546 max_length = int (min (max (len (max (metrics_strs , key = len )), len (max (headers , key = len )), 25 ), term_size / 2 ))
546547
547- rows : List [ List [Any ]] = [[] for _ in metrics_paths ]
548+ rows : list [ list [Any ]] = [[] for _ in metrics_paths ]
548549
549550 for result in results :
550551 for metric , row in zip (metrics_paths , rows ):
0 commit comments