Commit 5da2e43

apply fixes
1 parent b373a76

5 files changed: +30 −28 lines changed

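Every hunk below applies the same mechanical cleanup: the deprecated `typing` aliases (`List`, `Dict`, `Tuple`, `DefaultDict`) become the built-in generics that PEP 585 made subscriptable in Python 3.9, and the abstract container types (`Iterable`, `Iterator`) move to their home in `collections.abc`. Notably, `Optional` and `Union` stay, presumably because the `X | Y` syntax only evaluates at runtime on Python 3.10+. A minimal sketch of the pattern (the names below are illustrative, not taken from the diff):

    # Illustrative sketch only -- hypothetical names, not code from this commit.
    from collections import defaultdict
    from collections.abc import Iterable
    from typing import Any, Optional, Union  # kept: `X | Y` needs Python 3.10+

    # PEP 585: built-in and stdlib containers are subscriptable since 3.9.
    seen: defaultdict[int, int] = defaultdict(int)
    limit: Optional[Union[int, float]] = None

    def keys(data: dict[str, Any]) -> Iterable[tuple[str, ...]]:
        """Yield each top-level key as a one-element tuple."""
        for k in data:
            yield (k,)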

src/lightning/pytorch/loops/evaluation_loop.py

Lines changed: 11 additions & 10 deletions
@@ -15,8 +15,9 @@
 import shutil
 import sys
 from collections import ChainMap, OrderedDict, defaultdict
+from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, DefaultDict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
@@ -67,17 +68,17 @@ def __init__(
         self.verbose = verbose
         self.inference_mode = inference_mode
         self.batch_progress = _BatchProgress()  # across dataloaders
-        self._max_batches: List[Union[int, float]] = []
+        self._max_batches: list[Union[int, float]] = []

         self._results = _ResultCollection(training=False)
-        self._logged_outputs: List[_OUT_DICT] = []
+        self._logged_outputs: list[_OUT_DICT] = []
         self._has_run: bool = False
         self._trainer_fn = trainer_fn
         self._stage = stage
         self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
-        self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
+        self._seen_batches_per_dataloader: defaultdict[int, int] = defaultdict(int)
         self._last_val_dl_reload_epoch = float("-inf")
         self._module_mode = _ModuleMode()
         self._restart_stage = RestartStage.NONE
@@ -90,7 +91,7 @@ def num_dataloaders(self) -> int:
         return len(combined_loader.flattened)

     @property
-    def max_batches(self) -> List[Union[int, float]]:
+    def max_batches(self) -> list[Union[int, float]]:
         """The max number of batches to run per dataloader."""
         max_batches = self._max_batches
         if not self.trainer.sanity_checking:
@@ -114,7 +115,7 @@ def _is_sequential(self) -> bool:
         return self._combined_loader._mode == "sequential"

     @_no_grad_context
-    def run(self) -> List[_OUT_DICT]:
+    def run(self) -> list[_OUT_DICT]:
         self.setup_data()
         if self.skip:
             return []
@@ -280,7 +281,7 @@ def on_run_start(self) -> None:
         self._on_evaluation_start()
         self._on_evaluation_epoch_start()

-    def on_run_end(self) -> List[_OUT_DICT]:
+    def on_run_end(self) -> list[_OUT_DICT]:
         """Runs the ``_on_evaluation_epoch_end`` hook."""
         # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
         self.trainer._logger_connector.epoch_end_reached()
@@ -508,7 +509,7 @@ def _verify_dataloader_idx_requirement(self) -> None:
         )

     @staticmethod
-    def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]:
+    def _get_keys(data: dict) -> Iterable[tuple[str, ...]]:
         for k, v in data.items():
             if isinstance(v, dict):
                 for new_key in apply_to_collection(v, dict, _EvaluationLoop._get_keys):
@@ -527,7 +528,7 @@ def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]:
             return _EvaluationLoop._find_value(result, rest)

     @staticmethod
-    def _print_results(results: List[_OUT_DICT], stage: str) -> None:
+    def _print_results(results: list[_OUT_DICT], stage: str) -> None:
         # remove the dl idx suffix
         results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
         metrics_paths = {k for keys in apply_to_collection(results, dict, _EvaluationLoop._get_keys) for k in keys}
@@ -544,7 +545,7 @@ def _print_results(results: List[_OUT_DICT], stage: str) -> None:
         term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120
         max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2))

-        rows: List[List[Any]] = [[] for _ in metrics_paths]
+        rows: list[list[Any]] = [[] for _ in metrics_paths]

         for result in results:
             for metric, row in zip(metrics_paths, rows):
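One detail worth noting in the import hunk above: `typing.Iterable` and `typing.Iterator` are thin wrappers whose runtime origin is the corresponding `collections.abc` class, so redirecting the import is behavior-preserving. A quick standard-library check (nothing Lightning-specific):

    import collections.abc
    import typing

    # The typing alias resolves to the ABC itself, so isinstance checks and
    # parameterized annotations behave identically after the import swap.
    assert typing.Iterable.__origin__ is collections.abc.Iterable
    assert isinstance([], collections.abc.Iterable)
    print(collections.abc.Iterable[tuple[str, ...]])  # parameterizes at runtime on 3.9+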

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 4 additions & 4 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 import torch
 from typing_extensions import override
@@ -104,7 +104,7 @@ def __init__(

         self._data_source = _DataLoaderSource(None, "train_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
-        self._combined_loader_states_to_load: List[Dict[str, Any]] = []
+        self._combined_loader_states_to_load: list[dict[str, Any]] = []
         self._data_fetcher: Optional[_DataFetcher] = None
         self._last_train_dl_reload_epoch = float("-inf")
         self._restart_stage = RestartStage.NONE
@@ -504,14 +504,14 @@ def teardown(self) -> None:
         self.epoch_loop.teardown()

     @override
-    def on_save_checkpoint(self) -> Dict:
+    def on_save_checkpoint(self) -> dict:
         state_dict = super().on_save_checkpoint()
         if self._combined_loader is not None and (loader_states := self._combined_loader._state_dicts()):
             state_dict["combined_loader"] = loader_states
         return state_dict

     @override
-    def on_load_checkpoint(self, state_dict: Dict) -> None:
+    def on_load_checkpoint(self, state_dict: dict) -> None:
         self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
         super().on_load_checkpoint(state_dict)
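Beyond the `Dict` → `dict` rename, the `on_save_checkpoint`/`on_load_checkpoint` pair above shows the conventions that keep checkpoints both lean and backward compatible: the walrus expression only writes the `combined_loader` key when there is non-empty loader state, and the `.get(..., [])` on the way back tolerates checkpoints that predate the key. A standalone sketch of that idiom (the `LoaderStateMixin` class is hypothetical):

    from typing import Any

    class LoaderStateMixin:
        """Hypothetical sketch of conditionally persisted optional sub-state."""

        def __init__(self) -> None:
            self._loader_states: list[dict[str, Any]] = []

        def on_save_checkpoint(self) -> dict:
            state_dict: dict = {}
            # Write the key only when there is something to restore later.
            if loader_states := self._loader_states:
                state_dict["combined_loader"] = loader_states
            return state_dict

        def on_load_checkpoint(self, state_dict: dict) -> None:
            # `.get` with a default tolerates older checkpoints lacking the key.
            self._loader_states = state_dict.get("combined_loader", [])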

src/lightning/pytorch/loops/loop.py

Lines changed: 6 additions & 6 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional
+from typing import Optional

 import lightning.pytorch as pl
 from lightning.pytorch.loops.progress import _BaseProgress
@@ -41,7 +41,7 @@ def restarting(self, restarting: bool) -> None:
     def reset_restart_stage(self) -> None:
         pass

-    def on_save_checkpoint(self) -> Dict:
+    def on_save_checkpoint(self) -> dict:
         """Called when saving a model checkpoint, use to persist loop state.

         Returns:
@@ -50,10 +50,10 @@ def on_save_checkpoint(self) -> Dict:
         """
         return {}

-    def on_load_checkpoint(self, state_dict: Dict) -> None:
+    def on_load_checkpoint(self, state_dict: dict) -> None:
         """Called when loading a model checkpoint, use to reload loop state."""

-    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
+    def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> dict:
         """The state dict is determined by the state and progress of this loop and all its children.

         Args:
@@ -77,7 +77,7 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:

     def load_state_dict(
         self,
-        state_dict: Dict,
+        state_dict: dict,
         prefix: str = "",
     ) -> None:
         """Loads the state of this loop and all its children."""
@@ -88,7 +88,7 @@
         self.restarting = True
         self._loaded_from_state_dict = True

-    def _load_from_state_dict(self, state_dict: Dict, prefix: str) -> None:
+    def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
         for k, v in self.__dict__.items():
             key = prefix + k
             if key not in state_dict:
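loop.py holds the generic persistence contract the other files implement: `state_dict` flattens a loop and its children into a single dict keyed by dotted prefixes, and `_load_from_state_dict` filters by the same prefix on the way back. A toy version of that prefix scheme (heavily simplified; the real `_Loop` also walks `_BaseProgress` objects and child loops automatically):

    from typing import Optional

    class TinyLoop:
        """Toy stand-in for the prefix-based (de)serialization in `_Loop`."""

        def __init__(self) -> None:
            self.counter = 0
            self.child: Optional["TinyLoop"] = None

        def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> dict:
            if destination is None:
                destination = {}
            destination[prefix + "counter"] = self.counter
            if self.child is not None:
                # Children serialize into the same flat dict under a longer prefix.
                self.child.state_dict(destination, prefix + "child.")
            return destination

        def load_state_dict(self, state_dict: dict, prefix: str = "") -> None:
            for key, value in state_dict.items():
                if key == prefix + "counter":
                    self.counter = value
            if self.child is not None:
                self.child.load_state_dict(state_dict, prefix + "child.")

    outer = TinyLoop()
    outer.child = TinyLoop()
    outer.counter, outer.child.counter = 3, 7
    flat = outer.state_dict()  # {'counter': 3, 'child.counter': 7}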

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@
 import math
 from collections import OrderedDict
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union

 from typing_extensions import override

@@ -390,13 +390,13 @@ def teardown(self) -> None:
         self.val_loop.teardown()

     @override
-    def on_save_checkpoint(self) -> Dict:
+    def on_save_checkpoint(self) -> dict:
         state_dict = super().on_save_checkpoint()
         state_dict["_batches_that_stepped"] = self._batches_that_stepped
         return state_dict

     @override
-    def on_load_checkpoint(self, state_dict: Dict) -> None:
+    def on_load_checkpoint(self, state_dict: dict) -> None:
         self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

     def _accumulated_batches_reached(self) -> bool:
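Both hooks here are also marked with `typing_extensions.override` (a backport of Python 3.12's `typing.override`), which lets a type checker flag the subclass method if the base-class hook in `_Loop` is ever renamed. A minimal sketch of what the decorator buys:

    from typing_extensions import override

    class Base:
        def on_save_checkpoint(self) -> dict:
            return {}

    class Child(Base):
        @override  # a misspelling such as `on_save_checkpont` becomes a type error
        def on_save_checkpoint(self) -> dict:
            state_dict = super().on_save_checkpoint()
            state_dict["_batches_that_stepped"] = 0  # illustrative payload
            return state_dict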

tests/tests_pytorch/loops/test_loops.py

Lines changed: 6 additions & 5 deletions
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections.abc import Iterator
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Any, Dict, Iterator
+from typing import Any
 from unittest.mock import ANY, Mock

 import pytest
@@ -87,10 +88,10 @@ def advance(self) -> None:

             self.outputs.append(value)

-        def state_dict(self) -> Dict:
+        def state_dict(self) -> dict:
             return {"iteration_count": self.iteration_count, "outputs": self.outputs}

-        def load_state_dict(self, state_dict: Dict) -> None:
+        def load_state_dict(self, state_dict: dict) -> None:
             self.iteration_count = state_dict["iteration_count"]
             self.outputs = state_dict["outputs"]

@@ -140,10 +141,10 @@ def advance(self) -> None:
                 return
             loop.run()

-        def on_save_checkpoint(self) -> Dict:
+        def on_save_checkpoint(self) -> dict:
             return {"a": self.a}

-        def on_load_checkpoint(self, state_dict: Dict) -> None:
+        def on_load_checkpoint(self, state_dict: dict) -> None:
             self.a = state_dict["a"]

     trainer = Trainer()
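The test changes are annotation-only, but the class they live in is a good illustration of the contract the loops above implement: whatever `state_dict` returns must be enough for `load_state_dict` to resume. A condensed round trip in the same spirit (self-contained, without the test's trainer fixtures):

    class CountingLoop:
        """Minimal stand-in for the test's custom loop."""

        def __init__(self) -> None:
            self.iteration_count = 0
            self.outputs: list[int] = []

        def state_dict(self) -> dict:
            return {"iteration_count": self.iteration_count, "outputs": self.outputs}

        def load_state_dict(self, state_dict: dict) -> None:
            self.iteration_count = state_dict["iteration_count"]
            self.outputs = state_dict["outputs"]

    src = CountingLoop()
    src.iteration_count, src.outputs = 3, [0, 1, 2]
    dst = CountingLoop()
    dst.load_state_dict(src.state_dict())
    assert dst.state_dict() == src.state_dict()  # resuming reproduces the snapshot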
