
Commit 42169a2

Authored by johnhenning, with rohitgr7 and Carlos Mocholí
Add typing to LightningModule.trainer (#12345)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 2de6a9b commit 42169a2

File tree: 14 files changed (+71, -76 lines)

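Every change follows the same pattern: `LightningModule.trainer` is annotated as `Optional["pl.Trainer"]`, and call sites that dereference it narrow the type first with an assert. A minimal sketch of that pattern, assuming `pytorch_lightning` is importable (the stand-in class and `call_into_trainer` helper below are illustrative, not part of the diff):

from typing import Optional

import pytorch_lightning as pl


class ModuleWithTrainerRef:
    """Stand-in for LightningModule: holds an optional back-reference to the Trainer."""

    def __init__(self) -> None:
        # pointer to the trainer object; stays None until a Trainer attaches the module
        self.trainer: Optional["pl.Trainer"] = None


def call_into_trainer(model: ModuleWithTrainerRef) -> None:
    # mypy sees `model.trainer` as Optional, so narrow it before dereferencing,
    # mirroring the `assert model.trainer is not None` lines added in this commit
    assert model.trainer is not None
    print(model.trainer.training)  # e.g. query the trainer's state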

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}")
 
         # pointer to the trainer object
-        self.trainer = None
+        self.trainer: Optional["pl.Trainer"] = None
 
         self._use_amp: bool = False

pytorch_lightning/core/optimizer.py

Lines changed: 1 addition & 0 deletions
@@ -176,6 +176,7 @@ def _init_optimizers_and_lr_schedulers(
     model: "pl.LightningModule",
 ) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
     """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
+    assert model.trainer is not None
     optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
 
     if optim_conf is None:

pytorch_lightning/overrides/base.py

Lines changed: 24 additions & 29 deletions
@@ -57,13 +57,10 @@ def on_post_move_to_device(self) -> None:
 
 
 class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]):
-        """
-        Wraps the user's LightningModule and redirects the forward call to the appropriate
-        method, either ``training_step``, ``validation_step`` or ``test_step``.
-        If the LightningModule is in none of the states `training`, `testing` or `validation`,
-        the inputs will be redirected to the
-        :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
+    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
+        """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
+        ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
+
         Inheriting classes may also modify the inputs or outputs of forward.
 
         Args:
@@ -77,28 +74,26 @@ def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionMod
         self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore]
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        lightning_module = unwrap_lightning_module(self.module)
-        trainer = lightning_module.trainer
-
-        if trainer and trainer.training:
-            output = self.module.training_step(*inputs, **kwargs)
-
-            # In manual_optimization, we need to prevent DDP reducer as
-            # it is done manually in `LightningModule.manual_backward`
-            # `require_backward_grad_sync` will be reset in the
-            # ddp_strategy `post_training_step` hook
-            if not lightning_module.automatic_optimization:
-                trainer.model.require_backward_grad_sync = False
-        elif trainer and trainer.testing:
-            output = self.module.test_step(*inputs, **kwargs)
-        elif trainer and (trainer.sanity_checking or trainer.validating):
-            output = self.module.validation_step(*inputs, **kwargs)
-        elif trainer and trainer.predicting:
-            output = self.module.predict_step(*inputs, **kwargs)
-        else:
-            output = self.module(*inputs, **kwargs)
-
-        return output
+        pl_module = unwrap_lightning_module(self.module)
+        trainer = pl_module.trainer
+
+        if trainer is not None:
+            if trainer.training:
+                output = self.module.training_step(*inputs, **kwargs)
+                # In manual_optimization, we need to prevent DDP reducer as
+                # it is done manually in `LightningModule.manual_backward`
+                # `require_backward_grad_sync` will be reset in the
+                # ddp_strategy `post_training_step` hook
+                if not pl_module.automatic_optimization:
+                    trainer.model.require_backward_grad_sync = False  # type: ignore[assignment]
+                return output
+            if trainer.testing:
+                return self.module.test_step(*inputs, **kwargs)
+            if trainer.sanity_checking or trainer.validating:
+                return self.module.validation_step(*inputs, **kwargs)
+            if trainer.predicting:
+                return self.module.predict_step(*inputs, **kwargs)
+        return self.module(*inputs, **kwargs)
 
     def on_post_move_to_device(self) -> None:
         pass
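For context on the DDP reducer comment kept in `forward`: that branch only triggers for modules that run manual optimization. A hedged sketch of such a module, using Lightning's documented manual-optimization API (the model itself is made up):

import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    """Hypothetical model that opts into manual optimization."""

    def __init__(self) -> None:
        super().__init__()
        # With manual optimization, the wrapper's forward() above sets
        # `require_backward_grad_sync = False` so the DDP reducer does not run twice.
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)  # backward handled here instead of by Lightning
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)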

pytorch_lightning/overrides/data_parallel.py

Lines changed: 10 additions & 8 deletions
@@ -13,12 +13,12 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any, Union
+from typing import Any, cast, Union
 
 import torch
 
 import pytorch_lightning as pl
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
@@ -36,10 +36,11 @@ def _ignore_scalar_return_in_dp() -> None:
 
 class LightningParallelModule(_LightningModuleWrapperBase):
     """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
-    ``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination with
-    :class:`~torch.nn.parallel.DataParallel` as shown in the example. It also takes care of converting Python
-    scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required by
-    :class:`~torch.nn.parallel.DataParallel`.
+    ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
+
+    This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as shown in the example.
+    It also takes care of converting Python scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required
+    by :class:`~torch.nn.parallel.DataParallel`.
 
     Example:
@@ -53,7 +54,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         pl_module: the model to wrap
     """
 
-    def __init__(self, pl_module: "pl.LightningModule") -> None:
+    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
         super().__init__(pl_module)
         _ignore_scalar_return_in_dp()
 
@@ -63,7 +64,8 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         output = super().forward(*inputs, **kwargs)
 
         def output_transform(data: Any) -> Any:
-            data = python_scalar_to_tensor(data, self.module.device)
+            device = cast(torch.device, self.module.device)
+            data = python_scalar_to_tensor(data, device)
             data = unsqueeze_scalar_tensor(data)
             return data
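The `cast` added above only affects the type checker: `typing.cast` performs no runtime conversion and simply returns its argument, so `output_transform` behaves exactly as before. A quick stand-alone demonstration (the string value is just an example):

from typing import cast

import torch

# `cast` only annotates for static analysis; at runtime the value passes through untouched.
device = cast(torch.device, "cpu")
print(type(device))  # <class 'str'> -- no conversion happened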

pytorch_lightning/overrides/distributed.py

Lines changed: 3 additions & 20 deletions
@@ -12,36 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Iterator, List, Sized, Union
+from typing import Any, cast, Iterable, Iterator, List, Sized, Union
 
 import torch
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler
 
-import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.utilities import rank_zero_deprecation
 
 
 class LightningDistributedModule(_LightningModuleWrapperBase):
-    def __init__(self, pl_module: "pl.LightningModule") -> None:
-        """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
-        ``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination
-        with :class:`~torch.nn.parallel.DistributedDataParallel` as shown in the example.
-
-        Example:
-
-            ddp_model = torch.nn.parallel.DistributedDataParallel(
-                module=LightningDistributedModule(lightning_module),
-                device_ids=[local_rank],
-                ...
-            )
-
-        Args:
-            pl_module: the model to wrap
-        """
-        super().__init__(pl_module)
+    ...
 
 
 def _find_tensors(
@@ -164,5 +147,5 @@ def batch_size(self) -> int:
         return self._sampler.batch_size
 
     @property
-    def sampler(self) -> Sampler:
+    def sampler(self) -> Union[Sampler, Iterable]:
         return self._sampler.sampler
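The docstring removed above held the only usage example for `LightningDistributedModule`. For reference, the wrapping it described looks roughly like this; `lightning_module` and `local_rank` are placeholders carried over from that example, and further DDP arguments were elided in the original:

import torch

from pytorch_lightning.overrides.distributed import LightningDistributedModule

ddp_model = torch.nn.parallel.DistributedDataParallel(
    module=LightningDistributedModule(lightning_module),  # wraps the user's LightningModule
    device_ids=[local_rank],
    # ... further DistributedDataParallel arguments elided in the original example
)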

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def backward(
             closure_loss: the loss value obtained from the closure
             optimizer: current optimizer being used. ``None`` if using manual optimization
         """
+        assert model.trainer is not None
         opt = optimizer or model.trainer.optimizers
         with amp.scale_loss(closure_loss, opt) as closure_loss:
             super().backward(model, closure_loss, optimizer, *args, **kwargs)

pytorch_lightning/plugins/precision/deepspeed.py

Lines changed: 7 additions & 1 deletion
@@ -46,6 +46,7 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
                 "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
                 " the backward logic internally."
             )
+        assert model.trainer is not None
         deepspeed_engine: DeepSpeedEngine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
 
@@ -75,7 +76,12 @@ def optimizer_step(
                 "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
             )
         # DeepSpeed handles the optimizer step internally
-        deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
+        deepspeed_engine: DeepSpeedEngine
+        if isinstance(model, pl.LightningModule):
+            assert model.trainer is not None
+            deepspeed_engine = model.trainer.model
+        else:
+            deepspeed_engine = model
         return deepspeed_engine.step(**kwargs)
 
     def clip_gradients(

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,7 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Ten
             model: the model to be optimized
             closure_loss: the loss value obtained from the closure
         """
+        assert model.trainer is not None
         model.trainer._call_callback_hooks("on_before_backward", closure_loss)
         model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
         return closure_loss
@@ -89,6 +90,7 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
         """
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
+        assert model.trainer is not None
         model.trainer._call_callback_hooks("on_after_backward")
         model.trainer._call_lightning_module_hook("on_after_backward")
         return closure_loss

pytorch_lightning/strategies/bagua.py

Lines changed: 8 additions & 2 deletions
@@ -6,7 +6,11 @@
 from torch.nn import Module
 
 import pytorch_lightning as pl
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
+from pytorch_lightning.overrides.base import (
+    _LightningModuleWrapperBase,
+    _LightningPrecisionModuleWrapperBase,
+    unwrap_lightning_module,
+)
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -32,7 +36,7 @@
 
 
 class LightningBaguaModule(_LightningModuleWrapperBase):
-    def __init__(self, pl_module: "pl.LightningModule") -> None:
+    def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
         super().__init__(pl_module)
         # Bagua use `bagua_module_name` to distinguish different modules
         self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}"
@@ -161,6 +165,7 @@ def configure_ddp(self) -> None:
         self._model = self._setup_model(model)
 
         # start the background communication for async algorithm
+        assert self.lightning_module.trainer is not None
         if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.resume(self.model)  # type: ignore
 
@@ -188,6 +193,7 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
 
     def teardown(self) -> None:
         # abort the background communication for async algorithm
+        assert self.lightning_module.trainer is not None
         if self.lightning_module.trainer.training and self._bagua_algorithm == "async":
             self.model.bagua_algorithm.abort(self.model)  # type: ignore

pytorch_lightning/strategies/deepspeed.py

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
@@ -67,7 +67,9 @@ def remove_module_hooks(model: torch.nn.Module) -> None:
 
 
 class LightningDeepSpeedModule(_LightningModuleWrapperBase):
-    def __init__(self, pl_module: "pl.LightningModule", precision: int) -> None:
+    def __init__(
+        self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: int
+    ) -> None:
         super().__init__(pl_module)
         self.precision = precision

0 commit comments