Skip to content

Commit f4ea5d8

Browse files
committed
Import lightning.pytorch instead of pytorch_lightning
1 parent 585e7a6 commit f4ea5d8

File tree

4 files changed

+11
-13
lines changed

4 files changed

+11
-13
lines changed

src/nn_core/callbacks.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from pathlib import Path
33
from typing import Any, Dict, Optional
44

5-
import pytorch_lightning as pl
5+
import lightning.pytorch as pl
6+
from lightning.pytorch import Callback, Trainer
67
from omegaconf import DictConfig
7-
from pytorch_lightning import Callback, Trainer
88

99
from nn_core.model_logging import NNLogger
1010
from nn_core.resume import parse_restore
@@ -57,4 +57,3 @@ def on_save_checkpoint(
5757
metadata = getattr(trainer.datamodule, "metadata", None)
5858
if metadata is not None:
5959
checkpoint[METADATA_KEY] = metadata
60-

src/nn_core/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import dotenv
77
import numpy as np
88
from hydra.core.hydra_config import HydraConfig
9+
from lightning.pytorch import seed_everything
910
from omegaconf import DictConfig
10-
from pytorch_lightning import seed_everything
1111
from rich.prompt import Prompt
1212

1313
pylogger = logging.getLogger(__name__)

src/nn_core/model_logging.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
from typing import Any, Dict, Optional, Union
77

88
import hydra
9-
import pytorch_lightning
9+
from lightning.pytorch import LightningModule, Trainer
10+
from lightning.pytorch.callbacks import ModelCheckpoint
11+
from lightning.pytorch.loggers.logger import Logger
1012
from omegaconf import DictConfig, OmegaConf
11-
from pytorch_lightning import LightningModule, Trainer
12-
from pytorch_lightning.callbacks import ModelCheckpoint
13-
from pytorch_lightning.loggers.logger import Logger
1413

1514
from nn_core.common import PROJECT_ROOT
1615

@@ -115,7 +114,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
115114
116115
This method logs metrics as as soon as it received them. If you want to aggregate
117116
metrics for one specific `step`, use the
118-
:meth:`~pytorch_lightning.loggers.base.Logger.agg_and_log_metrics` method.
117+
:meth:`~lightning.pytorch.loggers.base.Logger.agg_and_log_metrics` method.
119118
120119
Args:
121120
metrics: Dictionary with metric names as keys and measured quantities as values
@@ -167,7 +166,7 @@ def run_dir(self) -> str:
167166

168167
def log_configuration(
169168
self,
170-
model: pytorch_lightning.LightningModule,
169+
model: LightningModule,
171170
cfg: Union[Dict[str, Any], argparse.Namespace, DictConfig] = None,
172171
*args,
173172
**kwargs,

src/nn_core/serialization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from pathlib import Path
1111
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
1212

13-
import pytorch_lightning as pl
13+
import lightning.pytorch as pl
1414
import torch
15-
from pytorch_lightning.core.saving import _load_state
16-
from pytorch_lightning.plugins import TorchCheckpointIO
15+
from lightning.pytorch.core.saving import _load_state
16+
from lightning.pytorch.plugins import TorchCheckpointIO
1717

1818
METADATA_KEY: str = "metadata"
1919

0 commit comments

Comments
 (0)