|
35 | 35 | from torch import Tensor |
36 | 36 | from typing_extensions import override |
37 | 37 |
|
38 | | -import pytorch_lightning as pl |
| 38 | +import lightning.pytorch as pl |
| 39 | +from lightning.pytorch.callbacks import Checkpoint |
| 40 | +from lightning.pytorch.utilities.exceptions import MisconfigurationException |
| 41 | +from lightning.pytorch.utilities.rank_zero import ( |
| 42 | + WarningCache, |
| 43 | + rank_zero_info, |
| 44 | + rank_zero_warn, |
| 45 | +) |
| 46 | +from lightning.pytorch.utilities.types import STEP_OUTPUT |
39 | 47 | from lightning_fabric.utilities.cloud_io import ( |
40 | 48 | _is_dir, |
41 | 49 | _is_local_file_protocol, |
42 | 50 | get_filesystem, |
43 | 51 | ) |
44 | 52 | from lightning_fabric.utilities.types import _PATH |
45 | | -from pytorch_lightning.callbacks import Checkpoint |
46 | | -from pytorch_lightning.utilities.exceptions import MisconfigurationException |
47 | | -from pytorch_lightning.utilities.rank_zero import ( |
48 | | - WarningCache, |
49 | | - rank_zero_info, |
50 | | - rank_zero_warn, |
51 | | -) |
52 | | -from pytorch_lightning.utilities.types import STEP_OUTPUT |
53 | 53 |
|
54 | 54 | log = logging.getLogger(__name__) |
55 | 55 | warning_cache = WarningCache() |
@@ -252,7 +252,7 @@ def __init__( |
252 | 252 | self.best_k_models: dict[str, dict[str, Tensor | dict[str, Tensor]]] = {} |
253 | 253 | self.kth_best_model_path = "" |
254 | 254 | self.best_model_score: Optional[Tensor] = None |
255 | | - self.best_model_metrics: Optional[Dict[str, Tensor]] = None |
| 255 | + self.best_model_metrics: Optional[dict[str, Tensor]] = None |
256 | 256 | self.best_model_path = "" |
257 | 257 | self.last_model_path = "" |
258 | 258 | self._last_checkpoint_saved = "" |
|
0 commit comments